diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index 85bffc245..f628354cc 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -22,3 +22,9 @@
ef448d5271e7ce4a2de912dd2a76a629e7a569cb
# blacken with black version 22
b1d93c708b314444fecd1722438cec6f91f028d6
+# blacken with black version 23
+84af0c2b4c1f6912450b62f11b9065051ece4763
+1a5d1ca0be3938eb46975d97325bd25883370523
+fd7c998dfd0889aba3bf0c6ef93964d514404e15
+e2d5e28404e7ae218040cf0004992a125fb6bd65
+60d03fe82f6a4f3ff0a8c70fac150d1325d91b86
diff --git a/AUTHORS.rst b/AUTHORS.rst
index 3769f4853..d2d950ec4 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -20,7 +20,7 @@ Active developers are indicated by (*). Authors of the PINT paper are indicated
* Anne Archibald (#*)
* Matteo Bachetti (#)
* Bastian Beischer
-* Deven Bhakta
+* Deven Bhakta (*)
* Chloe Champagne (#)
* Jonathan Colen (#)
* Thankful Cromartie
@@ -28,6 +28,7 @@ Active developers are indicated by (*). Authors of the PINT paper are indicated
* Paul Demorest (#)
* Julia Deneva
* Justin Ellis
+* William Fiore (*)
* Fabian Jankowski
* Rick Jenet (#)
* Ross Jennings (#*)
diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md
new file mode 100644
index 000000000..a9e07480b
--- /dev/null
+++ b/CHANGELOG-unreleased.md
@@ -0,0 +1,27 @@
+# Changelog
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project, at least loosely, adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+This file contains the unreleased changes to the codebase. See CHANGELOG.md for
+the released changes.
+
+## Unreleased
+### Changed
+- Third-order Roemer delay terms to ELL1 model
+- Made the addition of a TZR TOA (`AbsPhase`) in the `TimingModel` explicit in `Residuals` class.
+- Updated `CONTRIBUTING.rst` with the latest information.
+- Made `TimingModel.params` and `TimingModel.ordered_params` identical. Deprecated `TimingModel.ordered_params`.
+### Added
+- Third-order Roemer delay terms to ELL1 model
+- Options to add a TZR TOA (`AbsPhase`) during the creation of a `TimingModel` using `ModelBuilder.__call__`, `get_model`, and `get_model_and_toas`
+- `pint.print_info()` function for bug reporting
+- Added an autocorrelation function to check for chain convergence in `event_optimize`
+### Fixed
+- Deleting JUMP1 from flag tables will not prevent fitting
+- Simulating TOAs from tim file when PLANET_SHAPIRO is true now works
+- Docstrings for `get_toas()` and `get_model_and_toas()`
+- Set `DelayComponent_list` and `NoiseComponent_list` to empty list if such components are absent
+- Fix invalid access of `PLANET_SHAPIRO` in models without `Astrometry`
+### Removed
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b0b64a20d..57f52f546 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,39 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project, at least loosely, adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-## Unreleased
+This file contains the released changes to the codebase. See CHANGELOG-unreleased.md for
+the unreleased changes. This file should only be changed while tagging a new version.
+
+## [0.9.6] 2023-06-22
+### Changed
+- Applied `sourcery` refactors to the entire codebase
+- Changed threshold for `test_model_derivatives` test to avoid CI failures
+- Unreleased CHANGELOG entries should now be entered in `CHANGELOG-unreleased.md` instead of `CHANGELOG.md`. Updated documentation accordingly.
+- Changed tests to remove `unittest` and use pure pytest format
+- Changed deprecated `sampler.chain` usage
+- Download data automatically in the profiling script `high_level_benchmark.py` instead of silently giving wrong results.
+### Added
+- `SpindownBase` as the abstract base class for `Spindown` and `PeriodSpindown` in the `How_to_build_a_timing_model_component.py` example.
+- `SolarWindDispersionBase` as the abstract base class for solar wind dispersion components.
+- `validate_component_types` method for more rigorous validation of timing model components.
+- roundtrip test to make sure clock corrections are not written to tim files
+- `calc_phase_mean` and `calc_time_mean` methods in `Residuals` class to compute the residual mean.
+- `PhaseOffset` component (overall phase offset between physical and TZR toas)
+- `tzr` attribute in `TOAs` class to identify TZR TOAs
+- Documentation: Explanation for offsets
+- Example: `phase_offset_example.py`
+- method `AllComponents.param_to_unit` to get units for any parameter, and then made function `utils.get_unit`
+- can override/add parameter values when reading models
+- docs now include list of observatories along with google maps links and clock files
+### Fixed
+- fixed docstring for `add_param_from_top`
+- Gridded calculations now respect logger settings
+- Event TOAs now have default error that is non-zero, and can set as desired
+- Model conversion ICRS <-> ECL works if PM uncertainties are not set
+- Fix `merge_TOAs()` to allow lists of length 1
+### Removed
+
+## [0.9.5] 2023-05-01
### Changed
- Changed minimum supported version of `scipy` to 1.4.1
- Moved `DMconst` from `pint.models.dispersion_model` to `pint` to avoid circular imports
@@ -12,14 +44,16 @@ and this project, at least loosely, adheres to [Semantic Versioning](https://sem
- Refactor `Dre` method, fix expressions for Einstein delay and post-Keplerian parameters in DD model
- Updated contributor list (AUTHORS.rst)
- Emit an informative warning for "MODE" statement in TOA file; Ignore "MODE 1" silently
-- Version of `sphinx-rtd-theme` updated in `requirements_dev.txt`
+- Version of `sphinx-rtd-theme` updated in `requirements_dev.txt`
- Updated `black` version to 23.x
+- Older event loading functions now use newer functions to create TOAs and then convert to list of TOA objects
+- Limited hypothesis to <= 6.72.0 to avoid numpy problems in oldestdeps
### Added
- Documentation: Explanation for DM
- Methods to compute dispersion slope and to convert DM using the CODATA value of DMconst
- `TimingModel.total_dispersion_slope` method
- Explicit discussion of DT92 convention to DDK model
-- HESS and ORT telescopes to the list of known observatories
+- HAWC, HESS and ORT telescopes to the list of known observatories
- Documentation: making TOAs from array of times added to HowTo
- Method to make TOAs from an array of times
- Clock correction for LEAP
@@ -31,13 +65,25 @@ and this project, at least loosely, adheres to [Semantic Versioning](https://sem
- `funcParameters` defined as functions operating on other parameters
- Option to save `emcee` backend chains in `event_optimize`
- Documentation on how to extract a covariance matrix
+- DDS and DDGR models
+- Second-order corrections included in ELL1
+- Module for converting between binary models also included in `convert_parfile`
+- Method to get a parameter as a `uncertainties.ufloat` for doing math
+- Method to get current binary period and uncertainty at a given time regardless of binary model
+- TCB to TDB conversion on read, and conversion script (`tcb2tdb`)
+- Functions to get TOAs objects from satellite data (Fermi and otherwise)
+- Methods to convert a TOAs object into a list of TOA objects
### Fixed
+- Syntax error in README.rst
- Broken notebooks CI test
- BIPM correction for simulated TOAs
- Added try/except to `test_pldmnoise.py`/`test_PLRedNoise_recovery` to avoid exceptions during CI
- Import for `longdouble2str` in `get_tempo_result`
- Plotting orbital phase in `pintk` when FB0 is used instead of PB
+- Selection of BIPM for random models
- Added 1 sigma errors to update the postfit parfile errors in `event_optimize`
+- Fixed DDS CI testing failures
+- Add SolarSystemShapiro to the timing model only if an Astrometry component is present.
### Removed
## [0.9.3] 2022-12-16
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index c8d49c528..f20ad9483 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -20,11 +20,14 @@ Report bugs at https://github.com/nanograv/pint/issues.
If you are reporting a bug, please include:
-* Your operating system name and version.
-* The output of ``pint.__version__`` and ``pint.__file__``
+* The output of ``pint.print_info()``. This command provides the version information of
+ the OS, Python, PINT, and the various dependencies along with other information about
+ your system.
* Any details about your local setup that might be helpful in troubleshooting,
- such as the command used to install PINT and whether you are using a virtualenv.
-* Detailed steps to reproduce the bug, as simply as possible.
+ such as the command used to install PINT and whether you are using a virtualenv,
+ conda environment, etc.
+* Detailed steps to reproduce the bug, as simply as possible. A self-contained
+ code snippet that triggers the issue will be most helpful.
Submit Feedback
~~~~~~~~~~~~~~~
@@ -74,22 +77,37 @@ to write good documentation, you come to understand the code very well.
Get Started!
------------
-Ready to contribute? Here's how to set up PINT for local development.
+Ready to contribute? Here's how to set up `PINT` for local development.
-1. Fork_ the ``pint`` repo on GitHub.
+1. Fork_ the `PINT` repo on GitHub.
2. Clone your fork locally::
$ git clone git@github.com:your_name_here/pint.git
-3. Install your local copy into a virtualenv. Assuming you have
- virtualenvwrapper installed, this is how you set up your fork for local
+3. Install your local copy into a `conda`_ environment. Assuming you have
+ `conda` installed, this is how you set up your fork for local
development::
- $ mkvirtualenv pint
+ $ conda create -n pint-devel python=3.10
+ $ conda activate pint-devel
+ $ cd PINT/
+ $ conda install -c conda-forge --file requirements_dev.txt
+ $ conda install -c conda-forge --file requirements.txt
+ $ pip install -e .
+ $ pre-commit install
+
+ The last command installs pre-commit hooks which will squawk at you while trying
+ to commit changes that don't adhere to our `Coding Style`_.
+
+ Alternatively, this can also be done using `virtualenv`. Assuming you have
+ `virtualenvwrapper` installed, this is how you set up your fork for local
+ development::
+
+ $ mkvirtualenv pint-devel
$ cd PINT/
$ pip install -r requirements_dev.txt
- $ pip install -r requirements.txt
$ pip install -e .
+ $ pre-commit install
4. Create a branch for local development::
@@ -107,13 +125,13 @@ Ready to contribute? Here's how to set up PINT for local development.
6. Commit your changes and push your branch to GitHub::
$ git add .
- $ git commit -m "Your detailed description of your changes."
+ $ git commit -m "Detailed description of your changes."
$ git push origin name-of-your-bugfix-or-feature
7. Submit a pull request through the GitHub website.
-8. Check that our automatic testing "Travis CI" passes your code. If
- problems crop up, fix them, commit the changes, and push a new version,
+8. Check that our automatic testing in "GitHub Actions" passes for your code.
+ If problems crop up, fix them, commit the changes, and push a new version,
which will automatically update the pull request::
$ git add pint/file-i-just-fixed.py
@@ -125,13 +143,14 @@ Ready to contribute? Here's how to set up PINT for local development.
functional changes. If accepted, it will be merged into the master branch.
.. _Fork: https://help.github.com/en/articles/fork-a-repo
+.. _`conda`: https://docs.conda.io/
Pull Request Guidelines
-----------------------
Before you submit a pull request, check that it meets these guidelines:
-1. Try to write clear :ref:`pythonic` code, follow our :ref:`CodingStyle`, and think
+1. Try to write clear `Pythonic`_ code, follow our `Coding Style`_, and think
about how others might use your new code.
2. The pull request should include tests that cover both the expected
behavior and sensible error reporting when given bad input.
@@ -139,7 +158,13 @@ Before you submit a pull request, check that it meets these guidelines:
be updated. Put your new functionality into a function with a
docstring. Check the HTML documentation produced by ``make docs``
to make sure your new documentation appears and looks reasonable.
-4. The pull request should work for Python 2.7 and 3.6+. Check
- https://travis-ci.org/nanograv/pint/pull_requests
- and make sure that the tests pass for all supported Python versions.
-
+ If the new functionality needs a more detailed explanation than can be
+ put in a docstring, add it to ``docs/explanation.rst``. Make sure that
+ the docstring contains a brief description as well.
+4. The pull request should work for and 3.8+. Make sure that all the
+ CI tests for the pull request pass.
+5. Update ``CHANGELOG-unreleased.md`` with an appropriate entry. Please note
+ that ``CHANGELOG.md`` should not be updated for pull requests.
+
+.. _`Pythonic`: https://peps.python.org/pep-0008/
+.. _`Coding Style`: https://nanograv-pint.readthedocs.io/en/latest/coding-style.html
\ No newline at end of file
diff --git a/LICENSE.md b/LICENSE.md
index 078302e06..75233b867 100644
--- a/LICENSE.md
+++ b/LICENSE.md
@@ -1,4 +1,4 @@
-Copyright 2014-2017, PINT developers
+Copyright 2014-2023, PINT developers
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
diff --git a/README.rst b/README.rst
index fac6ed137..c65b18a24 100644
--- a/README.rst
+++ b/README.rst
@@ -36,16 +36,17 @@ PINT
PINT is not TEMPO3
------------------
-PINT is a project to develop a new pulsar timing solution based on
+PINT is a project to develop a pulsar timing solution based on
python and modern libraries. It is still in active development,
-but it can already produce residuals from most "normal"
+but it is in production use by the NANOGrav collaboration and
+it has been demonstrated produce residuals from most "normal"
timing models that agree with Tempo and Tempo2 to within ~10
nanoseconds. It can be used within python scripts or notebooks,
and there are several command line tools that come with it.
-The primary reasons we are developing PINT are:
+The primary reasons PINT was developed are:
-* To have a robust system to check high-precision timing results that is
+* To have a robust system to produce high-precision timing results that is
completely independent of TEMPO and Tempo2
* To make a system that is easy to extend and modify due to a good design
@@ -57,7 +58,13 @@ IMPORTANT Notes!
PINT has a naming conflict with the `pint `_ units package available from PyPI (i.e. using pip) and conda.
Do **NOT** ``pip install pint`` or ``conda install pint``! See below!
-PINT requires `longdouble` arithmetic within `numpy`, which is currently not supported natively on M1 Macs (e.g., with the `ARM64 conda build `_). So it may be better to install the standard `osx-64` build and rely on Rosetta.
+PINT requires ``longdouble`` (80- or 128-bit floating point) arithmetic within ``numpy``, which is currently not supported natively on M1/M2 Macs.
+However, you can use an x86 version of ``conda`` even on an M1/M2 Mac (which will run under Rosetta emulation):
+see `instructions for using Apple Intel packages on Apple
+silicon `_.
+It's possible to have `parallel versions of conda for x86 and
+ARM `_.
+
Installing
----------
@@ -96,19 +103,17 @@ it, ensuring that all dependencies needed to run PINT are available::
$ cd PINT
$ pip install .
-Complete installation instructions are available here_.
-
-.. _here: https://nanograv-pint.readthedocs.io/en/latest/installation.html
+Complete installation instructions are available on `readthedocs `_.
Using
-----
-See the online documentation_. Specifically:
+See the online documentation_. Specifically:
-* `tutorials `_
+* `Tutorials `_
* `API reference `_
-* `How to's for common tasks `_
+* `How-Tos for common tasks `_
Are you a NANOGrav member? Then join the #pint channel in the NANOGrav slack.
@@ -119,10 +124,13 @@ email pint@nanograv.org or one of the people below:
* Paul Ray (Paul.Ray@nrl.navy.mil)
* David Kaplan (kaplan@uwm.edu)
-Want to do something new? Submit a github `issue `_.
+Want to do something new? Submit a github `issue `_.
.. _documentation: http://nanograv-pint.readthedocs.io/en/latest/
And for more details, please read and cite(!) the PINT paper_.
.. _paper: https://ui.adsabs.harvard.edu/abs/2021ApJ...911...45L/abstract
+
+Articles that cite the PINT paper can be found in an ADS `Library `_.
+A list of software packages that use PINT can be found `here `_.
diff --git a/docs/_ext/componentlist.py b/docs/_ext/componentlist.py
index 8d9534d71..63bffd6c3 100644
--- a/docs/_ext/componentlist.py
+++ b/docs/_ext/componentlist.py
@@ -1,9 +1,6 @@
from docutils import nodes
from docutils.parsers.rst import Directive
-from docutils.parsers.rst.directives.tables import Table
-from docutils.parsers.rst.directives import unchanged_required
from docutils.statemachine import ViewList
-import pint.utils
class ComponentList(Directive):
diff --git a/docs/_ext/paramtable.py b/docs/_ext/paramtable.py
index a988b29df..67cde7b18 100644
--- a/docs/_ext/paramtable.py
+++ b/docs/_ext/paramtable.py
@@ -1,5 +1,4 @@
from docutils import nodes
-from docutils.parsers.rst import Directive
from docutils.parsers.rst.directives.tables import Table
from docutils.parsers.rst.directives import unchanged_required
from docutils.statemachine import ViewList
@@ -55,17 +54,14 @@ def run(self):
entry += para
elif c == "name":
text = d[c]
- alias_list = d.get("aliases", [])
- if alias_list:
+ if alias_list := d.get("aliases", []):
text += " / " + ", ".join(d["aliases"])
entry += nodes.paragraph(text=text)
elif isinstance(d[c], str):
entry += nodes.paragraph(text=d[c])
elif isinstance(d[c], list):
entry += nodes.paragraph(text=", ".join(d[c]))
- elif d[c] is None:
- pass
- else:
+ elif d[c] is not None:
entry += nodes.paragraph(text=str(d[c]))
tbody += row
tgroup += tbody
diff --git a/docs/_ext/sitetable.py b/docs/_ext/sitetable.py
new file mode 100644
index 000000000..eff3c022e
--- /dev/null
+++ b/docs/_ext/sitetable.py
@@ -0,0 +1,100 @@
+from docutils import nodes
+from docutils.parsers.rst.directives.tables import Table
+from docutils.parsers.rst.directives import unchanged_required
+from docutils.statemachine import ViewList
+import urllib.parse
+import pint.observatory
+import numpy as np
+
+_iptaclock_baseurl = "https://ipta.github.io/pulsar-clock-corrections"
+_googlesearch_baseurl = "https://www.google.com/maps/search/?"
+
+
+class SiteTable(Table):
+ option_spec = {"class": unchanged_required}
+ has_content = False
+
+ def run(self):
+ columns = [
+ ("Name / Aliases", "name", 10),
+ ("Origin", "origin", 50),
+ ("Location", "location", 20),
+ ("Clock File(s)", "clock", 20),
+ ]
+
+ class_ = None
+
+ table = nodes.table()
+ tgroup = nodes.tgroup(len(columns))
+ table += tgroup
+
+ thead = nodes.thead()
+ row = nodes.row()
+ for label, _, w in columns:
+ tgroup += nodes.colspec(colwidth=w)
+ entry = nodes.entry()
+ row += entry
+ entry += nodes.paragraph(text=label)
+ thead += row
+ tgroup += thead
+
+ tbody = nodes.tbody()
+ for name in sorted(pint.observatory.Observatory.names()):
+ o = pint.observatory.get_observatory(name)
+ row = nodes.row()
+ for _, c, _ in columns:
+ entry = nodes.entry()
+ row += entry
+ if c == "name":
+ entry += nodes.strong(text=name)
+ if len(o.aliases) > 0:
+ entry += nodes.paragraph(text=" (" + ", ".join(o.aliases) + ")")
+ elif c == "origin":
+ entry += nodes.paragraph(text=o.origin)
+ elif c == "location":
+ loc = o.earth_location_itrf()
+ if loc is not None:
+ lat = loc.lat.value
+ lon = loc.lon.value
+ text = f"{np.abs(lat):.4f}{'N' if lat >=0 else 'S'}, {np.abs(lon):.4f}{'E' if lon >= 0 else 'W'}"
+ # https://developers.google.com/maps/documentation/urls/get-started
+ url = _googlesearch_baseurl + urllib.parse.urlencode(
+ {"api": "1", "query": f"{lat},{lon}"}
+ )
+ para = nodes.paragraph()
+ refnode = nodes.reference("", "", internal=False, refuri=url)
+ innernode = nodes.emphasis(text, text)
+ refnode.append(innernode)
+ para += refnode
+ entry += para
+ elif c == "clock":
+ if hasattr(o, "clock_files"):
+ for clockfile in o.clock_files:
+ clockfilename = (
+ clockfile
+ if isinstance(clockfile, str)
+ else clockfile["name"]
+ )
+ para = nodes.paragraph()
+ dirname = "tempo" if o.clock_fmt == "tempo" else "T2runtime"
+ url = f"{_iptaclock_baseurl}/{dirname}/clock/{clockfilename}.html"
+ refnode = nodes.reference(
+ "", "", internal=False, refuri=url
+ )
+ innernode = nodes.emphasis(clockfilename, clockfilename)
+ refnode.append(innernode)
+ para += refnode
+ entry += para
+ tbody += row
+ tgroup += tbody
+ return [table]
+
+
+def setup(app):
+ app.add_directive("sitetable", SiteTable)
+
+ return {
+ "version": "0.1",
+ "parallel_read_safe": True,
+ "parallel_write_safe": True,
+ }
diff --git a/docs/command-line.rst b/docs/command-line.rst
index 04ba8ef76..fa04b56c8 100644
--- a/docs/command-line.rst
+++ b/docs/command-line.rst
@@ -111,3 +111,11 @@ the examples subdirectory of the PINT distro.
event_optimize J0030+0451_P8_15.0deg_239557517_458611204_ft1weights_GEO_wt.gt.0.4.fits PSRJ0030+0451_psrcat.par templateJ0030.3gauss --weightcol=PSRJ0030+0451 --minWeight=0.9 --nwalkers=100 --nsteps=500
+tcb2tdb
+-------
+
+A command line tool that converts par files from TCB timescale to TDB timescale.
+
+::
+
+ tcb2tdb J0030+0451_tcb.par J0030+0451_tdb.par
\ No newline at end of file
diff --git a/docs/conf.py b/docs/conf.py
index f0061ff76..f7da720bc 100755
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -15,7 +15,6 @@
import os
import sys
-from packaging.version import parse
import jupytext
@@ -58,6 +57,7 @@
"nbsphinx",
"paramtable",
"componentlist",
+ "sitetable",
#'IPython.sphinxext.ipython_console_highlighting',
]
diff --git a/docs/dependent-packages.rst b/docs/dependent-packages.rst
new file mode 100644
index 000000000..31267b05c
--- /dev/null
+++ b/docs/dependent-packages.rst
@@ -0,0 +1,24 @@
+Packages using PINT
+===================
+
+This page provides a list of some of the software packages that use PINT.
+
+* `APT `_
+* `baseband-tasks `_
+* `chimera `_
+* `darkmatter_and_pint `_
+* `ENTERPRISE `_
+* `enterprise_extensions `_
+* `enterprise_outliers `_
+* `GammaPulsar `_
+* `HENDRICS `_
+* `NICERSoft `_
+* `nustar-clock-utils `_
+* `pint_pal `_
+* `PsrSigSim `_
+* `pymars `_
+* `stingray `_
+* `tat-pulsar `_
+* `TOAextractor `_
+
+
diff --git a/docs/development-setup.rst b/docs/development-setup.rst
index 17598ace9..a5fa8737e 100644
--- a/docs/development-setup.rst
+++ b/docs/development-setup.rst
@@ -245,10 +245,10 @@ First make sure you are on the PINT master branch in the ``nanograv/PINT`` repos
Now wait 15 minutes and check that travis-ci says that the build is OK, before tagging!
If needed, push any bug fixes.
-Next, check the CHANGELOG and make sure all the significant changes from PRs since the last
-release have been documented.
-Then, change the # Unreleased section of the CHANGELOG to the version number you are about
-to tag and commit, but don't yet push.
+Next, check the unreleased CHANGELOG (`CHANGELOG-unreleased.md`) and make sure all the
+significant changes from PRs since the last release have been documented. Move these entries
+to the released CHANGELOG (`CHANGELOG.md`), and change title of the newly moved entries
+from "Unreleased" to the version number you are about to tag and commit. **But don't yet push**.
When tagging, always use "annotated tags" by specifying ``-a``, so do these commands to tag and push::
diff --git a/docs/examples/How_to_build_a_timing_model_component.py b/docs/examples/How_to_build_a_timing_model_component.py
index b92f33b7c..ab7f09640 100644
--- a/docs/examples/How_to_build_a_timing_model_component.py
+++ b/docs/examples/How_to_build_a_timing_model_component.py
@@ -28,7 +28,7 @@
# * Add the new component to the `TimingModel` class
# * Use the functions in the `TimingModel` class to interact with the new component.
#
-# We will build a simple model component, pulsar spindow model with spin period as parameters, instead of spin frequency.
+# We will build a simple model component, pulsar spindown model with spin period as parameters, instead of spin frequency.
# %% [markdown]
# ## Import the necessary modules
@@ -36,11 +36,11 @@
# %%
import numpy as np # Numpy is a widely used package
-# PINT uses astropy units in the internal cacluation and is highly recommended for a new component
+# PINT uses astropy units in the internal calculation and is highly recommended for a new component
import astropy.units as u
# Import the component classes.
-from pint.models.timing_model import TimingModel, Component, PhaseComponent
+from pint.models.spindown import SpindownBase
import pint.models.parameter as p
import pint.config
import pint.logging
@@ -90,15 +90,15 @@
# %%
-class PeriodSpindown(PhaseComponent):
- """This is an example model component of pular spindown but parametrized as period."""
+class PeriodSpindown(SpindownBase):
+ """This is an example model component of pulsar spindown but parametrized as period."""
register = True # Flags for the model builder to find this component.
# define the init function.
# Most components do not have a parameter for input.
def __init__(self):
- # Get the attruibutes that initilzed in the parent class
+ # Get the attributes that initialized in the parent class
super().__init__()
# Add parameters using the add_params in the TimingModel
# Add spin period as parameter
diff --git a/docs/examples/MCMC_walkthrough.broken b/docs/examples/MCMC_walkthrough.broken
index 4bc55a9ec..bb910e1b6 100644
--- a/docs/examples/MCMC_walkthrough.broken
+++ b/docs/examples/MCMC_walkthrough.broken
@@ -127,7 +127,8 @@ To make this run relatively fast for demonstration purposes, nsteps was purposef
```python
fitter.phaseogram()
-samples = sampler.sampler.chain[:, 10:, :].reshape((-1, fitter.n_fit_params))
+samples = np.transpose(sampler.sampler.get_chain(discard=10), (1, 0, 2)).reshape(
+ (-1, fitter.n_fit_params))
ranges = map(
lambda v: (v[1], v[2] - v[1], v[1] - v[0]),
zip(*np.percentile(samples, [16, 50, 84], axis=0)),
@@ -192,7 +193,7 @@ fitter2.fit_toas(maxiter=nsteps2, pos=None)
```
```python
-samples2 = sampler2.sampler.chain[:, :, :].reshape((-1, fitter2.n_fit_params))
+samples2 = np.transpose(sampler2.sampler.get_chain(), (1, 0, 2)).reshape((-1, fitter2.n_fit_params))
ranges2 = map(
lambda v: (v[1], v[2] - v[1], v[1] - v[0]),
zip(*np.percentile(samples2, [16, 50, 84], axis=0)),
diff --git a/docs/examples/PINT_walkthrough.py b/docs/examples/PINT_walkthrough.py
index 1a0d05809..64e4d45be 100644
--- a/docs/examples/PINT_walkthrough.py
+++ b/docs/examples/PINT_walkthrough.py
@@ -277,5 +277,5 @@
d = tempfile.mkdtemp()
pint.observatory.topo_obs.export_all_clock_files(d)
-for f in sorted(glob(d + "/*")):
+for f in sorted(glob(f"{d}/*")):
print(f)
diff --git a/docs/examples/bayesian-example-NGC6440E.py b/docs/examples/bayesian-example-NGC6440E.py
index adbb3bc60..daf571e8f 100644
--- a/docs/examples/bayesian-example-NGC6440E.py
+++ b/docs/examples/bayesian-example-NGC6440E.py
@@ -146,8 +146,9 @@
# TOAs in this dataset are from GBT).
# %%
-parfile = str(model) # casting the model to str gives the par file representation.
-parfile += "EFAC TEL gbt 1 1" # Add an EFAC to the par file and make it unfrozen.
+# casting the model to str gives the par file representation.
+# Add an EFAC to the par file and make it unfrozen.
+parfile = f"{str(model)}EFAC TEL gbt 1 1"
model2 = get_model(io.StringIO(parfile))
# %%
@@ -155,7 +156,7 @@
# Again, don't do this with real data. Use uninformative priors or priors
# motivated by previous experiments. This is done here with the sole purpose
# of making the run finish fast. Let us try this with the prior_info option now.
-prior_info = dict()
+prior_info = {}
for par in model2.free_params:
param = getattr(model2, par)
param_min = float(param.value - 10 * param.uncertainty_value)
@@ -201,4 +202,4 @@
print(f"Bayes factor : {bf} (in favor of no EFAC)")
# %% [markdown]
-# The Bayes factor tells us that the EFAC is unncessary for this dataset.
+# The Bayes factor tells us that the EFAC is unnecessary for this dataset.
diff --git a/docs/examples/compare_tempo2_J0613.py b/docs/examples/compare_tempo2_J0613.py
index d8e64d686..0d0f22fa3 100644
--- a/docs/examples/compare_tempo2_J0613.py
+++ b/docs/examples/compare_tempo2_J0613.py
@@ -1,4 +1,5 @@
"""Various tests to assess the performance of the J0623-0200."""
+
import pint.models.model_builder as mb
import pint.toa as toa
import pint.logging
@@ -42,7 +43,7 @@
presids_us = resids(toas, m).time_resids.to(u.us)
# Plot residuals
plt.errorbar(toas.get_mjds().value, presids_us.value, toas.get_errors().value, fmt="x")
-plt.title("%s Pre-Fit Timing Residuals" % m.PSR.value)
+plt.title(f"{m.PSR.value} Pre-Fit Timing Residuals")
plt.xlabel("MJD")
plt.ylabel("Residual (us)")
plt.grid()
diff --git a/docs/examples/fit_NGC6440E.py b/docs/examples/fit_NGC6440E.py
index 558d13d45..52ccb4030 100644
--- a/docs/examples/fit_NGC6440E.py
+++ b/docs/examples/fit_NGC6440E.py
@@ -61,7 +61,7 @@
# ```python
# # Define the timing model
# m = get_model(parfile)
-# # Read in the TOAs, using the solar system epemeris and other things from the model
+# # Read in the TOAs, using the solar system ephemeris and other things from the model
# t = pint.toa.get_TOAs(timfile, model=m)
# ```
diff --git a/docs/examples/fit_NGC6440E_MCMC.py b/docs/examples/fit_NGC6440E_MCMC.py
index dd047ea2d..ea4fa718c 100644
--- a/docs/examples/fit_NGC6440E_MCMC.py
+++ b/docs/examples/fit_NGC6440E_MCMC.py
@@ -26,12 +26,13 @@ def plot_chains(chain_dict, file=False):
fig.tight_layout()
if file:
fig.savefig(file)
- plt.close()
else:
plt.show()
- plt.close()
+
+ plt.close()
+import contextlib
import pint.config
parfile = pint.config.examplefile("NGC6440E.par.good")
@@ -51,7 +52,7 @@ def plot_chains(chain_dict, file=False):
rs = pint.residuals.Residuals(t, m).phase_resids
xt = t.get_mjds()
plt.plot(xt, rs, "x")
-plt.title("%s Pre-Fit Timing Residuals" % m.PSR.value)
+plt.title(f"{m.PSR.value} Pre-Fit Timing Residuals")
plt.xlabel("MJD")
plt.ylabel("Residual (phase)")
plt.grid()
@@ -83,21 +84,20 @@ def plot_chains(chain_dict, file=False):
# plotting the chains
chains = sampler.chains_to_dict(f.fitkeys)
-plot_chains(chains, file=f.model.PSR.value + "_chains.png")
+plot_chains(chains, file=f"{f.model.PSR.value}_chains.png")
# triangle plot
-# this doesn't include burn-in because we're not using it here, otherwise would have middle ':' --> 'burnin:'
-samples = sampler.sampler.chain[:, :, :].reshape((-1, f.n_fit_params))
-try:
+# this doesn't include burn-in because we're not using it here, otherwise set get_chain(discard=burnin)
+# samples = sampler.sampler.chain[:, :, :].reshape((-1, f.n_fit_params))
+samples = np.transpose(sampler.get_chain(), (1, 0, 2)).reshape((-1, ndim))
+with contextlib.suppress(ImportError):
import corner
fig = corner.corner(
samples, labels=f.fitkeys, bins=50, truths=f.maxpost_fitvals, plot_contours=True
)
- fig.savefig(f.model.PSR.value + "_triangle.png")
+ fig.savefig(f"{f.model.PSR.value}_triangle.png")
plt.close()
-except ImportError:
- pass
# Print some basic params
print("Best fit has reduced chi^2 of", f.resids.reduced_chi2)
@@ -112,7 +112,7 @@ def plot_chains(chain_dict, file=False):
t.get_errors().to(u.us).value,
fmt="x",
)
-plt.title("%s Post-Fit Timing Residuals" % m.PSR.value)
+plt.title(f"{m.PSR.value} Post-Fit Timing Residuals")
plt.xlabel("MJD")
plt.ylabel("Residual (us)")
plt.grid()
diff --git a/docs/examples/phase_offset_example.py b/docs/examples/phase_offset_example.py
new file mode 100644
index 000000000..bbb8b0fc5
--- /dev/null
+++ b/docs/examples/phase_offset_example.py
@@ -0,0 +1,139 @@
+#! /usr/bin/env python
+# ---
+# jupyter:
+# jupytext:
+# cell_metadata_filter: -all
+# formats: ipynb,py:percent
+# text_representation:
+# extension: .py
+# format_name: percent
+# format_version: '1.3'
+# jupytext_version: 1.14.4
+# kernelspec:
+# display_name: .env
+# language: python
+# name: python3
+# ---
+
+# %% [markdown]
+# # Demonstrate phase offset
+#
+# This notebook is primarily designed to operate as a plain `.py` script.
+# You should be able to run the `.py` script that occurs in the
+# `docs/examples/` directory in order to carry out a simple fit of a
+# timing model to some data. You should also be able to run the notebook
+# version as it is here (it may be necessary to `make notebooks` to
+# produce a `.ipynb` version using `jupytext`).
+
+# %%
+from pint.models import get_model_and_toas, PhaseOffset
+from pint.residuals import Residuals
+from pint.config import examplefile
+from pint.fitter import DownhillWLSFitter
+
+import matplotlib.pyplot as plt
+from astropy.visualization import quantity_support
+
+quantity_support()
+
+# %%
+# Read the TOAs and the model
+parfile = examplefile("J1028-5819-example.par")
+timfile = examplefile("J1028-5819-example.tim")
+model, toas = get_model_and_toas(parfile, timfile)
+
+# %%
+# Create a Residuals object
+res = Residuals(toas, model)
+
+# %%
+# By default, the residuals are mean-subtracted.
+resids1 = res.calc_phase_resids().to("")
+
+# We can disable mean subtraction by setting `subtract_mean` to False.
+resids2 = res.calc_phase_resids(subtract_mean=False).to("")
+
+# %%
+# Let us plot the residuals with and without mean subtraction.
+# In the bottom plot, there is clearly an offset between the two cases
+# although it is not so clear in the top plot.
+
+mjds = toas.get_mjds()
+errors = toas.get_errors() * model.F0.quantity
+
+plt.subplot(211)
+plt.errorbar(mjds, resids1, errors, ls="", marker="x", label="Mean subtracted")
+plt.errorbar(mjds, resids2, errors, ls="", marker="x", label="Not mean subtracted")
+plt.xlabel("MJD")
+plt.ylabel("Phase residuals")
+plt.axhline(0, ls="dotted", color="grey")
+plt.legend()
+
+plt.subplot(212)
+plt.plot(mjds, resids2 - resids1, ls="", marker="x")
+plt.xlabel("MJD")
+plt.ylabel("Phase residual difference")
+plt.show()
+
+# %%
+# This phase offset that gets subtracted implicitly can be computed
+# using the `calc_phase_mean` function. There is also a similar function
+# `calc_time_mean` for time offsets.
+
+implicit_offset = res.calc_phase_mean().to("")
+print("Implicit offset = ", implicit_offset)
+
+# %%
+# Now let us look at the design matrix.
+T, Tparams, Tunits = model.designmatrix(toas)
+print("Design matrix params :", Tparams)
+
+# The implicit offset is represented as "Offset".
+
+# %%
+# We can explicitly fit for this offset using the "PHOFF" parameter.
+# This is available in the "PhaseOffset" component
+
+po = PhaseOffset()
+model.add_component(po)
+assert hasattr(model, "PHOFF")
+
+# %%
+# Let us fit this now.
+
+model.PHOFF.frozen = False
+ftr = DownhillWLSFitter(toas, model)
+ftr.fit_toas()
+print(
+ f"PHOFF fit value = {ftr.model.PHOFF.value} +/- {ftr.model.PHOFF.uncertainty_value}"
+)
+
+# This is consistent with the implicit offset we got earlier.
+
+# %%
+# Let us plot the post-fit residuals.
+
+mjds = ftr.toas.get_mjds()
+errors = ftr.toas.get_errors() * model.F0.quantity
+resids = ftr.resids.calc_phase_resids().to("")
+
+plt.errorbar(mjds, resids, errors, ls="", marker="x", label="After fitting PHOFF")
+plt.xlabel("MJD")
+plt.ylabel("Phase residuals")
+plt.axhline(0, ls="dotted", color="grey")
+plt.legend()
+plt.show()
+
+# %%
+# Let us compute the phase residual mean again.
+phase_mean = ftr.resids.calc_phase_mean().to("")
+print("Phase residual mean = ", phase_mean)
+
+# i.e., we have successfully gotten rid of the offset by fitting PHOFF.
+
+# %%
+# Now let us look at the design matrix again.
+T, Tparams, Tunits = model.designmatrix(toas)
+print("Design matrix params :", Tparams)
+
+# The explicit offset "PHOFF" has replaced the implicit "Offset".
diff --git a/docs/explanation.rst b/docs/explanation.rst
index 35e37ac54..f1ee98869 100644
--- a/docs/explanation.rst
+++ b/docs/explanation.rst
@@ -147,12 +147,14 @@ in, and what kind of time you're asking for::
The conventional time scale for working with pulsars, and the one PINT
uses, is Barycentric Dynamical Time (TDB). You should be aware that there
-is another time scale, not yet supported in PINT, called Barycentric
-Coordinate Time (TCB), and that because of different handling of
-relativistic corrections, it does not advance at the same rate as TDB
+is another time scale, not yet fully supported in PINT, called Barycentric
+Coordinate Time (TCB). Because of different handling of relativistic
+corrections, the TCB timescale does not advance at the same rate as TDB
(there is also a many-second offset). TEMPO2 uses TCB by default, so
you may encounter pulsar timing models or even measurements that use
-TCB. PINT will attempt to detect this and let you know.
+TCB. PINT provides a command line tool `tcb2tdb` to approximately convert
+TCB timing models to TDB. PINT can also optionally convert TCB timing models
+to TDB (approximately) upon read.
Note that the need for leap seconds is because the Earth's rotation is
somewhat erratic - no, we're not about to be thrown off, but its
@@ -213,12 +215,49 @@ The total DM and dispersion slope predicted by a given timing model (:class:`pin
for a given set of TOAs (:class:`pint.toa.TOAs`) can be computed using :func:`pint.models.timing_model.TimingModel.total_dm`
and :func:`pint.models.timing_model.TimingModel.total_dispersion_slope` methods respectively.
+Offsets in pulsar timing
+------------------------
+Offsets arise in pulsar timing models for a variety of reasons. The different types of
+offsets are listed below:
+
+Overall phase offset
+''''''''''''''''''''
+The pulse phase corresponding to the TOAs are usually computed in reference to an arbitrary
+fiducial TOA known as the TZR TOA (see :class:`pint.models.absolute_phase.AbsPhase`). Since the
+choice of the TZR TOA is arbitrary, there can be an overall phase offset between the TZR TOA and
+the measured TOAs. There are three ways to account for this offset: (1) subtract the weighted mean
+from the timing residuals, (2) make the TZR TOA (given by the `TZRMJD` parameter) fittable, or
+(3) introduce a fittable phase offset parameter between measured TOAs and the TZR TOA.
+Traditionally, pulsar timing packages have opted to implicitly subtract the residual mean, and this
+is the default behavior of `PINT`. Option (2) is hard to implement because the TZR TOA may be
+specified at any observatory, and computing the TZR phase requires the application of the clock
+corrections. The explicit phase offset (option 3) can be invoked by adding the `PHOFF` parameter,
+(implemented in :class:`pint.models.phase_offset.PhaseOffset`). If the explicit offset `PHOFF`
+is given, the implicit residual mean subtraction behavior will be disabled.
+
+System-dependent delays
+'''''''''''''''''''''''
+It is very common to have TOAs for the same pulsar obtained using different observatories,
+telescope receivers, backend systems, and data processing pipelines, especially in long-running
+campaigns. Delays can arise between the TOAs measured using such different systems due to, among
+other reasons, instrumental delays, differences in algorithms used for RFI mitigation, folding, TOA
+measurement etc., and the choice of different template profiles used for TOA measurement. Such
+offsets are usually modeled using phase jumps (the `JUMP` parameter, see :class:`pint.models.jump.PhaseJump`)
+between TOAs generated from different systems.
+
+System-dependent DM offsets
+'''''''''''''''''''''''''''
+Similar to system-dependent delays, offsets can arise between wideband DM values measured using
+different systems due to the choice of template portraits with different fiducial DMs. This is
+usually modeled using DM jumps (the `DMJUMP` parameter, see :class:`pint.models.dispersion_model.DispersionJump`).
+
Observatories
-------------
-PINT comes with a number of defined observatories. Those on the surface of the Earth are :class:`~pint.observatory.topo_obs.TopoObs`
-instances. It can also pull in observatories from ``astropy``,
-and you can define your own. Observatories are generally referenced when reading TOA files, but can also be accessed directly::
+PINT comes with a number of defined observatories. Those on the surface of the Earth
+are :class:`~pint.observatory.topo_obs.TopoObs` instances. It can also pull in observatories
+from ``astropy``, and you can define your own. Observatories are generally referenced when
+reading TOA files, but can also be accessed directly::
import pint.observatory
gbt = pint.observatory.get_observatory("gbt")
@@ -228,7 +267,8 @@ Observatory definitions
Observatory definitions are included in ``pint.config.runtimefile("observatories.json")``.
To see the existing names, :func:`pint.observatory.Observatory.names_and_aliases` will
-return a dictionary giving all of the names (primary keys) and potential aliases (values).
+return a dictionary giving all of the names (primary keys) and potential aliases (values).
+You can also find the full list at :ref:`Observatory List`.
The observatory data are stored in JSON format. A simple example is::
@@ -244,12 +284,14 @@ The observatory data are stored in JSON format. A simple example is::
"origin": "The Robert C. Byrd Green Bank Telescope.\nThis data was obtained by Joe Swiggum from Ryan Lynch in 2021 September.\n"
}
-The observatory is defined by its name (``gbt``) and its position. This can be given as geocentric coordinates in the
-International_Terrestrial_Reference_System_ (ITRF) through the ``itrf_xyz`` triple (units as ``m``), or geodetic coordinates
-(WGS84_ assumed) through ``lat``, ``lon``, ``alt``
-(units are ``deg`` and ``m``). Conversion is done through Astropy_EarthLocation_.
+The observatory is defined by its name (``gbt``) and its position. This can be given as
+geocentric coordinates in the International_Terrestrial_Reference_System_ (ITRF) through
+the ``itrf_xyz`` triple (units as ``m``), or geodetic coordinates (WGS84_ assumed) through
+``lat``, ``lon``, ``alt`` (units are ``deg`` and ``m``). Conversion is done through
+Astropy_EarthLocation_.
-Other attributes are optional. Here we have also specified the ``tempo_code`` and ``itoa_code``, and a human-readable ``origin`` string.
+Other attributes are optional. Here we have also specified the ``tempo_code`` and
+``itoa_code``, and a human-readable ``origin`` string.
A more complex/complete example is::
@@ -280,19 +322,22 @@ A more complex/complete example is::
]
}
-Here we have included additional explicit ``aliases``, specified the clock format via ``clock_fmt``, and specified that the last entry in the
-clock file is bogus (``bogus_last_correction``). There are two clock files included in ``clock_file``:
+Here we have included additional explicit ``aliases``, specified the clock format via
+``clock_fmt``, and specified that the last entry in the clock file is bogus (``bogus_last_correction``).
+There are two clock files included in ``clock_file``:
* ``jbroach2jb.clk`` (where we also specify that it is ``valid_beyond_ends``)
* ``jb2gps.clk``
-These are combined to reference this particular telescope/instrument combination. For the full set of options, see :class:`~pint.observatory.topo_obs.TopoObs`.
+These are combined to reference this particular telescope/instrument combination.
+For the full set of options, see :class:`~pint.observatory.topo_obs.TopoObs`.
Adding New Observatories
''''''''''''''''''''''''
-In addition to modifying ``pint.config.runtimefile("observatories.json")``, there are other ways to add new observatories.
+In addition to modifying ``pint.config.runtimefile("observatories.json")``, there are other
+ways to add new observatories.
**Make sure you define any new observatory before you load any TOAs.**
1. You can define them pythonically:
@@ -302,7 +347,8 @@ In addition to modifying ``pint.config.runtimefile("observatories.json")``, ther
import astropy.coordinates
newobs = pint.observatory.topo_obs.TopoObs("newobs", location=astropy.coordinates.EarthLocation.of_site("keck"), origin="another way to get Keck")
-This can be done by specifying the ITRF coordinates, (``lat``, ``lon``, ``alt``), or a :class:`~astropy.coordinates.EarthLocation` instance.
+This can be done by specifying the ITRF coordinates, (``lat``, ``lon``, ``alt``), or a
+:class:`~astropy.coordinates.EarthLocation` instance.
2. You can include them just for the duration of your python session:
::
@@ -325,14 +371,17 @@ This can be done by specifying the ITRF coordinates, (``lat``, ``lon``, ``alt``)
}"""
load_observatories(io.StringIO(fakeGBT), overwrite=True)
-Note that since we are overwriting an existing observatory (rather than defining a completely new one) we specify ``overwrite=True``.
+Note that since we are overwriting an existing observatory (rather than defining a
+completely new one) we specify ``overwrite=True``.
-3. You can define them in a different file on disk. If you took the JSON above and put it into a file ``/home/user/anothergbt.json``,
+3. You can define them in a different file on disk. If you took the JSON above and
+put it into a file ``/home/user/anothergbt.json``,
you could then do::
export $PINT_OBS_OVERRIDE=/home/user/anothergbt.json
-(or the equivalent in your shell of choice) before you start any PINT scripts. By default this will overwrite any existing definitions.
+(or the equivalent in your shell of choice) before you start any PINT scripts.
+By default this will overwrite any existing definitions.
4. You can rely on ``astropy``. For instance:
::
@@ -340,7 +389,8 @@ you could then do::
import pint.observatory
keck = pint.observatory.Observatory.get("keck")
-will find Keck. :func:`astropy.coordinates.EarthLocation.get_site_names` will return a list of potential observatories.
+will find Keck. :func:`astropy.coordinates.EarthLocation.get_site_names` will return a list
+of potential observatories.
.. _International_Terrestrial_Reference_System: https://en.wikipedia.org/wiki/International_Terrestrial_Reference_System_and_Frame
.. _WGS84: https://en.wikipedia.org/wiki/World_Geodetic_System#WGS84
diff --git a/docs/howto.rst b/docs/howto.rst
index a27db55ea..0657a16e0 100644
--- a/docs/howto.rst
+++ b/docs/howto.rst
@@ -5,7 +5,8 @@ How-tos
This section is to provide detailed solutions to particular problems.
Some of these are based on user questions, others are more about how to develop PINT.
-Please feel free to change this by writing more; there are also some entries on the `PINT wiki `_ which you can contribute to more directly.
+Please feel free to change this by writing more; there are also some entries on the
+`PINT wiki `_ which you can contribute to more directly.
.. toctree::
diff --git a/docs/index.rst b/docs/index.rst
index 161ab164f..17cb82e53 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -68,8 +68,10 @@ communicate to them.
explanation
reference
howto
- authors
history
+ authors
+ dependent-packages
+
Indices and tables
==================
diff --git a/docs/installation.rst b/docs/installation.rst
index fd3819979..eb780c9ea 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -14,32 +14,29 @@ is more complicated (but not too much).
Prerequisites
-------------
-You need a python interpreter (either provided by your operating system or your favorite package manager).
-You should use Python 3.x -- it's time. Python 2 has been `sunset `_ as of January 1, 2020.
-Importantly, astropy versions 3 and later have completely dropped support for Python 2.
+PINT requires Python 3.8+ [1]_
-However, for PINT version 0.7.x and earlier both Python 2.7 and Python 3.5+ are supported.
+Your Python must have the package installation tool pip_ installed. Also make sure your ``setuptools`` are up to date (e.g. ``pip install -U setuptools``).
-For PINT versions 0.8 or later only Python 3.x will be supported.
+We highly recommend using a Conda_/`Anaconda `_ environment or the package isolation tool virtualenv_.
-Your Python must have the package installation tool pip_ installed. Also make sure your setuptools are up to date (e.g. ``pip install -U setuptools``).
-We highly recommend using an :ref:`Anaconda ` environment or the package isolation tool virtualenv_.
+IMPORTANT Notes!
+---------------
-TEMPO and Tempo2
-''''''''''''''''
+Naming conflict
+'''''''''''''''
-`TEMPO`_ is not required, but if you have it installed PINT can find clock
-correction files in ``$TEMPO/clock``
+PINT has a naming conflict with the `pint `_ units package available from PyPI (i.e. using pip) and conda.
+Do **NOT** ``pip install pint`` or ``conda install pint``! See :ref:`Basic Install via pip` or :ref:`Install with Anaconda`.
-`Tempo2`_ is not required, but if you have it installed PINT can find clock
-correction files in ``$TEMPO2/clock``
+Apple M1/M2 processors
+''''''''''''''''''''''
-IMPORTANT Note!
----------------
+PINT requires ``longdouble`` (80- or 128-bit floating point) arithmetic within ``numpy``, which is currently not supported natively on M1/M2 Macs.
+However, you can use an x86 version of ``conda`` even on an M1/M2 Mac: see `instructions for using Apple Intel packages on Apple silicon `_.
+It's possible to have `parallel versions of conda for x86 and ARM `_.
-PINT has a naming conflict with the `pint `_ units package available from PyPI (i.e. using pip) and conda.
-Do **NOT** ``pip install pint`` or ``conda install pint``! See below!
Basic Install via pip
---------------------
@@ -63,13 +60,10 @@ Install with Anaconda
---------------------
If you use `Anaconda `_ environments to manage your python packages,
-PINT is also available for Anaconda python under the `conda-forge `_ channel:
+PINT is also available for Anaconda python under the `conda-forge `_ channel::
$ conda install -c conda-forge pint-pulsar
-**NOTE**: PINT requires ``longdouble`` arithmetic within ``numpy``, which is currently not supported natively on M1 Macs (e.g., with the `ARM64 conda build `_). So it may be better to install the standard ``osx-64`` build and rely on Rosetta.
-
-
Install from Source
-------------------
@@ -145,8 +139,7 @@ by running ``pip install --user thing``. Unfortunately this causes something of
the same problem as having a ``PYTHONPATH`` set, where packages installed
outside your virtualenv can obscure the ones you have inside, producing bizarre
error messages. Record your current packages with ``pip freeze``, then try,
-outside a virtualenv, doing ``pip list`` with various options, and ``pip
-uninstall``; you shouldn't be able to uninstall anything system-wise (do not
+outside a virtualenv, doing ``pip list`` with various options, and ``pip uninstall``; you shouldn't be able to uninstall anything system-wise (do not
use ``sudo``!) and you shouldn't be able to uninstall anything in an inactive
virtualenv. So once you've blown away all those packages, you should be able to
work in clean virtualenvs. If you saved the output of ``pip freeze`` above, you
@@ -169,6 +162,7 @@ luck.
.. _virtualenv: https://virtualenv.pypa.io/en/latest/
.. _virtualenvwrapper: https://virtualenvwrapper.readthedocs.io/en/latest/
.. _Conda: https://docs.conda.io/en/latest/
+.. _Anaconda: https://www.anaconda.com
Installing PINT for Developers
------------------------------
@@ -194,10 +188,9 @@ Otherwise, there are several ways to `install pandoc`_
For further development instructions see :ref:`Developing PINT`
-.. _1: If you don't have `pip`_ installed, this `Python installation guide`_ can guide
- you through the process.
.. _pip: https://pip.pypa.io/en/stable/
-.. _TEMPO: http://tempo.sourceforge.net
-.. _Tempo2: https://bitbucket.org/psrsoft/tempo2
.. _pandoc: https://pandoc.org/
.. _`install pandoc`: https://pandoc.org/installing.html
+
+.. rubric:: Footnotes
+.. [1] Python 2.7 and 3.5+ are supported for PINT 0.7.x and earlier.
diff --git a/docs/observatory_list.rst b/docs/observatory_list.rst
new file mode 100644
index 000000000..c85acf597
--- /dev/null
+++ b/docs/observatory_list.rst
@@ -0,0 +1,8 @@
+.. _`Observatory List`:
+
+Observatory List
+================
+
+The current list of defined observatories is:
+
+.. sitetable::
diff --git a/docs/reference.rst b/docs/reference.rst
index 8e693d226..9ab55d08a 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -16,6 +16,7 @@ Useful starting places:
:maxdepth: 3
timingmodels
+ observatory_list
command-line
API reference <_autosummary/pint>
coding-style
diff --git a/docs/user-questions.rst b/docs/user-questions.rst
index a43d3f359..3a0584b6b 100644
--- a/docs/user-questions.rst
+++ b/docs/user-questions.rst
@@ -171,6 +171,34 @@ requested epoch. Similarly::
does the same for :class:`pint.models.astrometry.AstrometryEcliptic` (with an
optional specification of the obliquity).
+Convert between binary models
+-----------------------------
+
+If ``m`` is your initial model, say an ELL1 binary::
+
+ from pint import binaryconvert
+ m2 = binaryconvert.convert_binary(m, "DD")
+
+will convert it to a DD binary.
+
+Some binary types need additional parameters. For ELL1H, you can set the number of harmonics and whether to use H4 or STIGMA::
+
+ m2 = binaryconvert.convert_binary(m, "ELL1H", NHARMS=3, useSTIGMA=True)
+
+For DDK, you can set OM (known as ``KOM``)::
+
+ m2 = binaryconvert.convert_binary(mDD, "DDK", KOM=12 * u.deg)
+
+Parameter values and uncertainties will be converted. It will also make a best-guess as to which parameters should be frozen, but
+it can still be useful to refit with the new model and check which parameters are fit.
+
+.. note::
+ The T2 model from tempo2 is not implemented, as this is a complex model that actually encapsulates several models. The best practice is to
+ change the model to the actual underlying model (ELL1, DD, BT, etc).
+
+These conversions can also be done on the command line using ``convert_parfile``::
+
+ convert_parfile --binary=DD ell1.par -o dd.par
Add a jump programmatically
---------------------------
diff --git a/profiling/.gitignore b/profiling/.gitignore
new file mode 100644
index 000000000..f07669415
--- /dev/null
+++ b/profiling/.gitignore
@@ -0,0 +1,2 @@
+J0740+6620.cfr+19.tim
+bench_*_summary
\ No newline at end of file
diff --git a/profiling/README.txt b/profiling/README.txt
index f6239a4e0..59744b3da 100644
--- a/profiling/README.txt
+++ b/profiling/README.txt
@@ -16,7 +16,7 @@ curl -O https://data.nanograv.org/static/data/J0740+6620.cfr+19.tim
python high_level_benchmark.py
# To get useful output on an individual benchmarking script, do this to get
-# a list of the top 100 calls by execution time, as well as a PDF shwoing a tree of all the execution times.
+# a list of the top 100 calls by execution time, as well as a PDF showing a tree of all the execution times.
run_profile.py bench_MCMC.py
# The available benchmarks are (though more can be added!)
diff --git a/profiling/bench_chisq_grid.py b/profiling/bench_chisq_grid.py
index 34c098211..31909d265 100644
--- a/profiling/bench_chisq_grid.py
+++ b/profiling/bench_chisq_grid.py
@@ -36,8 +36,8 @@
thankftr_chi2grid = grid_chisq(thankftr, ("M2", "SINI"), (m2_grid, sini_grid), ncpu=1)
print()
-print("Number of TOAs: " + str(thanktoas.ntoas))
-print("Grid size of parameters: " + str(n) + "x" + str(n))
+print(f"Number of TOAs: {str(thanktoas.ntoas)}")
+print(f"Grid size of parameters: {n}x{n}")
print("Number of fits: 1")
print()
diff --git a/profiling/bench_chisq_grid_WLSFitter.py b/profiling/bench_chisq_grid_WLSFitter.py
index 07c799080..bb296536f 100644
--- a/profiling/bench_chisq_grid_WLSFitter.py
+++ b/profiling/bench_chisq_grid_WLSFitter.py
@@ -35,8 +35,8 @@
thankftr_chi2grid = grid_chisq(thankftr, ("M2", "SINI"), (m2_grid, sini_grid), ncpu=1)
print()
-print("Number of TOAs: " + str(thanktoas.ntoas))
-print("Grid size of parameters: " + str(n) + "x" + str(n))
+print(f"Number of TOAs: {str(thanktoas.ntoas)}")
+print(f"Grid size of parameters: {n}x{n}")
print("Number of fits: 1")
print()
diff --git a/profiling/bench_load_TOAs.py b/profiling/bench_load_TOAs.py
index 3a38c87bb..9532a41c3 100644
--- a/profiling/bench_load_TOAs.py
+++ b/profiling/bench_load_TOAs.py
@@ -16,5 +16,5 @@
include_bipm=True,
)
print()
-print("Number of TOAs: " + str(thanktoas.ntoas))
+print(f"Number of TOAs: {str(thanktoas.ntoas)}")
print()
diff --git a/profiling/high_level_benchmark.py b/profiling/high_level_benchmark.py
index fd89e03f1..d9c8ad5d2 100644
--- a/profiling/high_level_benchmark.py
+++ b/profiling/high_level_benchmark.py
@@ -11,18 +11,18 @@
import astropy
import numpy
import subprocess
-import cProfile
import pstats
import pint
import sys
import os
import platform
-from parser import parse_file
+import urllib.request
+from prfparser import parse_file
def bench_file(script):
outfile = script.replace(".py", "_prof_summary")
- cline = "python -m cProfile -o " + outfile + " " + script
+ cline = f"python -m cProfile -o {outfile} {script}"
print(cline)
# use DENULL to suppress logging output
subprocess.call(
@@ -33,30 +33,30 @@ def bench_file(script):
def get_results(script, outfile):
print("*******************************************************************")
- print("OUTPUT FOR " + script.upper() + ":")
- # put output in file for parsing
- f = open("bench.out", "w")
- old_stdout = sys.stdout
- sys.stdout = f
- # Check stats
- p = pstats.Stats(outfile)
- p.strip_dirs()
- # choose the functions to display
- if script == "bench_load_TOAs.py":
- p.print_stats("\(__init__", "toa")
- p.print_stats("\(apply_clock")
- p.print_stats("\(compute_TDBs")
- p.print_stats("\(compute_posvels")
- elif script == "bench_chisq_grid.py" or script == "bench_chisq_grid_WLSFitter.py":
- p.print_stats("\(get_designmatrix")
- p.print_stats("\(update_resid")
- p.print_stats("\(cho_factor")
- p.print_stats("\(cho_solve")
- p.print_stats("\(svd")
- p.print_stats("\(select_toa_mask")
- else:
- p.print_stats("only print total time") # for MCMC, only display total runtime
- f.close()
+ print(f"OUTPUT FOR {script.upper()}:")
+ with open("bench.out", "w") as f:
+ old_stdout = sys.stdout
+ sys.stdout = f
+ # Check stats
+ p = pstats.Stats(outfile)
+ p.strip_dirs()
+ # choose the functions to display
+ if script == "bench_load_TOAs.py":
+ p.print_stats("\(__init__", "toa")
+ p.print_stats("\(apply_clock")
+ p.print_stats("\(compute_TDBs")
+ p.print_stats("\(compute_posvels")
+ elif script in ["bench_chisq_grid.py", "bench_chisq_grid_WLSFitter.py"]:
+ p.print_stats("\(get_designmatrix")
+ p.print_stats("\(update_resid")
+ p.print_stats("\(cho_factor")
+ p.print_stats("\(cho_solve")
+ p.print_stats("\(svd")
+ p.print_stats("\(select_toa_mask")
+ else:
+ p.print_stats(
+ "only print total time"
+ ) # for MCMC, only display total runtime
# return output to terminal
sys.stdout = old_stdout
# parse file for desired info and format user-friendly output
@@ -68,7 +68,14 @@ def get_results(script, outfile):
parser = argparse.ArgumentParser(
description="High-level summary of python file timing."
)
- # scripts to be evaluated
+
+ if not os.path.isfile("J0740+6620.cfr+19.tim"):
+ print("Downloading data file J0740+6620.cfr+19.tim ...")
+ urllib.request.urlretrieve(
+ "https://data.nanograv.org/static/data/J0740+6620.cfr+19.tim",
+ "J0740+6620.cfr+19.tim",
+ )
+
script1 = "bench_load_TOAs.py"
script2 = "bench_chisq_grid.py"
script3 = "bench_chisq_grid_WLSFitter.py"
@@ -88,22 +95,17 @@ def get_results(script, outfile):
compID = cpuinfo.get_cpu_info()["brand_raw"]
else:
compID = "Unknown"
- print("Processor running this script: " + compID)
+ print(f"Processor running this script: {compID}")
pyversion = platform.python_version()
spversion = scipy.__version__
apversion = astropy.__version__
npversion = numpy.__version__
pintversion = pint.__version__
- print("Python version: " + pyversion)
+ print(f"Python version: {pyversion}")
print(
- "SciPy version: "
- + spversion
- + ", AstroPy version: "
- + apversion
- + ", NumPy version: "
- + npversion
+ f"SciPy version: {spversion}, AstroPy version: {apversion}, NumPy version: {npversion}"
)
- print("PINT version: " + pintversion)
+ print(f"PINT version: {pintversion}")
# output results
print()
diff --git a/profiling/parser.py b/profiling/prfparser.py
similarity index 94%
rename from profiling/parser.py
rename to profiling/prfparser.py
index 44c242c3a..35581af3a 100644
--- a/profiling/parser.py
+++ b/profiling/prfparser.py
@@ -1,4 +1,4 @@
-""" Parses inputed profiler output for function times. Designed specifically for
+""" Parses profiler output for function times. Designed specifically for
cProfile output files. Prints functions and times in neat, user-friendly layout.
Requires pandas package: https://pandas.pydata.org/
Install with: pip install pandas
@@ -15,8 +15,7 @@
def parse_line(line):
for key, attr in dictionary.items():
- match = attr.search(line)
- if match:
+ if match := attr.search(line):
return key, attr
return None, None
@@ -47,7 +46,7 @@ def parse_file(file):
# if match, read next line to get time
line = filename.readline()
n = 1
- # while there's a nonblank line under the keywork line...
+ # while there's a non-blank line under the keyword line...
while line.strip():
# extract values separated by spaces in the line, store in vals
vals = line.split()
diff --git a/profiling/run_profile.py b/profiling/run_profile.py
index 218e3e60c..0a237eae8 100755
--- a/profiling/run_profile.py
+++ b/profiling/run_profile.py
@@ -9,7 +9,8 @@
A .pdf file with the name + .pdf will be
generated for listing all the calls.
"""
-import cProfile
+
+
import argparse
import subprocess
import pstats
@@ -28,16 +29,12 @@
args = parser.parse_args()
outfile = args.script.replace(".py", "_profile")
if args.sort is None:
- cline = "python -m cProfile -o " + outfile + " " + args.script
+ cline = f"python -m cProfile -o {outfile} {args.script}"
else:
- cline = (
- "python -m cProfile -o " + outfile + " -s " + args.sort + " " + args.script
- )
+ cline = f"python -m cProfile -o {outfile} -s {args.sort} {args.script}"
print(cline)
subprocess.call(cline, shell=True)
- call_tree_line = (
- "gprof2dot -f pstats " + outfile + " | dot -Tpdf -o " + outfile + ".pdf"
- )
+ call_tree_line = f"gprof2dot -f pstats {outfile} | dot -Tpdf -o {outfile}.pdf"
subprocess.call(call_tree_line, shell=True)
# Check stats
p = pstats.Stats(outfile)
diff --git a/setup.cfg b/setup.cfg
index c62af7a06..80a0c6408 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -60,6 +60,7 @@ console_scripts =
pintk = pint.scripts.pintk:main
convert_parfile = pint.scripts.convert_parfile:main
compare_parfiles = pint.scripts.compare_parfiles:main
+ tcb2tdb = pint.scripts.tcb2tdb:main
# See the docstring in versioneer.py for instructions. Note that you must
diff --git a/src/pint/__init__.py b/src/pint/__init__.py
index 5f9acd0b9..f2bb62790 100644
--- a/src/pint/__init__.py
+++ b/src/pint/__init__.py
@@ -24,6 +24,7 @@
from pint.pulsar_ecliptic import PulsarEcliptic
from pint.pulsar_mjd import PulsarMJD, time_to_longdouble # ensure always loaded
+from pint.utils import info_string
__all__ = [
"__version__",
@@ -109,3 +110,8 @@
raise ValueError(
f"astropy version must be >=4 (currently it is {astropy.version.major})"
)
+
+
+def print_info():
+ """Print the OS version, Python version, PINT version, versions of the dependencies etc."""
+ print(info_string(detailed=True))
diff --git a/src/pint/bayesian.py b/src/pint/bayesian.py
index 2ec755079..8bb484b4b 100644
--- a/src/pint/bayesian.py
+++ b/src/pint/bayesian.py
@@ -94,20 +94,17 @@ def _decide_likelihood_method(self):
squares with normalization term (gls), for narrow-band (nb) or wide-band (wb)
dataset."""
- if (
- "NoiseComponent" in self.model.component_types
- and self.model.has_correlated_errors
+ if "NoiseComponent" not in self.model.component_types:
+ return "wls"
+ if correlated_errors_present := np.any(
+ [nc.introduces_correlated_errors for nc in self.model.NoiseComponent_list]
):
raise NotImplementedError(
"GLS likelihood for correlated noise is not yet implemented."
)
-
- if self.is_wideband:
- self.likelihood_method = "wls-wb"
- self._lnlikelihood = self._wls_wb_lnlikelihood
else:
- self.likelihood_method = "wls-nb"
- self._lnlikelihood = self._wls_nb_lnlikelihood
+ return "wls"
+ # return "gls"
def lnprior(self, params):
"""Basic implementation of a factorized log prior.
@@ -180,7 +177,7 @@ def lnposterior(self, params):
The value of the log-posterior at params
"""
lnpr = self.lnprior(params)
- return -np.inf if np.isnan(lnpr) else lnpr + self.lnlikelihood(params)
+ return lnpr + self.lnlikelihood(params) if np.isfinite(lnpr) else -np.inf
def _wls_nb_lnlikelihood(self, params):
"""Implementation of Log-Likelihood function for uncorrelated noise only for
diff --git a/src/pint/binaryconvert.py b/src/pint/binaryconvert.py
new file mode 100644
index 000000000..2b46e5cd2
--- /dev/null
+++ b/src/pint/binaryconvert.py
@@ -0,0 +1,1084 @@
+"""
+
+Potential issues:
+* orbital frequency derivatives
+* Does EPS1DOT/EPS2DOT imply OMDOT and vice versa?
+
+"""
+
+import numpy as np
+from astropy import units as u, constants as c
+from astropy.time import Time
+import copy
+from uncertainties import ufloat, umath
+from loguru import logger as log
+
+from pint import Tsun
+from pint.models.binary_bt import BinaryBT
+from pint.models.binary_dd import BinaryDD, BinaryDDS, BinaryDDGR
+from pint.models.binary_ddk import BinaryDDK
+from pint.models.binary_ell1 import BinaryELL1, BinaryELL1H, BinaryELL1k
+from pint.models.parameter import (
+ floatParameter,
+ MJDParameter,
+ intParameter,
+ funcParameter,
+)
+
+# output types
+# DDGR is not included as there is not a well-defined way to get a unique output
+binary_types = ["DD", "DDK", "DDS", "BT", "ELL1", "ELL1H", "ELL1k"]
+
+
+__all__ = ["convert_binary"]
+
+
+def _M2SINI_to_orthometric(model):
+ """Convert from standard Shapiro delay (M2, SINI) to orthometric (H3, H4, STIGMA)
+
+ Uses Eqns. 12, 20, 21 from Freire and Wex (2010)
+ Also propagates uncertainties if present
+
+ Note that both STIGMA and H4 should not be used
+
+ Paramters
+ ---------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ stigma : astropy.units.Quantity
+ h3 : astropy.units.Quantity
+ h4 : astropy.units.Quantity
+ stigma_unc : astropy.units.Quantity or None
+ Uncertainty on stigma
+ h3_unc : astropy.units.Quantity or None
+ Uncertainty on H3
+ h4_unc : astropy.units.Quantity or None
+ Uncertainty on H4
+
+ References
+ ----------
+ - Freire and Wex (2010), MNRAS, 409, 199 [1]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2010MNRAS.409..199F/abstract
+
+ """
+ if not (hasattr(model, "M2") and hasattr(model, "SINI")):
+ raise AttributeError(
+ "Model must contain M2 and SINI for conversion to orthometric parameters"
+ )
+ sini = model.SINI.as_ufloat()
+ m2 = model.M2.as_ufloat(u.Msun)
+ cbar = umath.sqrt(1 - sini**2)
+ stigma = sini / (1 + cbar)
+ h3 = Tsun.value * m2 * stigma**3
+ h4 = h3 * stigma
+
+ stigma_unc = stigma.s if stigma.s > 0 else None
+ h3_unc = h3.s * u.s if h3.s > 0 else None
+ h4_unc = h4.s * u.s if h4.s > 0 else None
+
+ return stigma.n, h3.n * u.s, h4.n * u.s, stigma_unc, h3_unc, h4_unc
+
+
+def _orthometric_to_M2SINI(model):
+ """Convert from orthometric (H3, H4, STIGMA) to standard Shapiro delay (M2, SINI)
+
+ Inverts Eqns. 12, 20, 21 from Freire and Wex (2010)
+ Also propagates uncertainties if present
+
+ If STIGMA is present will use that. Otherwise will use H4
+
+ Paramters
+ ---------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ M2 : astropy.units.Quantity.
+ SINI : astropy.units.Quantity
+ M2_unc : astropy.units.Quantity or None
+ Uncertainty on M2
+ SINI_unc : astropy.units.Quantity or None
+ Uncertainty on SINI
+
+ References
+ ----------
+ - Freire and Wex (2010), MNRAS, 409, 199 [1]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2010MNRAS.409..199F/abstract
+
+ """
+ if not (
+ hasattr(model, "H3") and (hasattr(model, "STIGMA") or hasattr(model, "H4"))
+ ):
+ raise AttributeError(
+ "Model must contain H3 and either STIGMA or H4 for conversion to M2/SINI"
+ )
+ h3 = model.H3.as_ufloat()
+ h4 = model.H4.as_ufloat() if model.H4.value is not None else None
+ stigma = model.STIGMA.as_ufloat() if model.STIGMA.value is not None else None
+
+ if stigma is not None:
+ sini = 2 * stigma / (1 + stigma**2)
+ m2 = h3 / stigma**3 / Tsun.value
+ else:
+ # FW10 Eqn. 25, 26
+ sini = 2 * h3 * h4 / (h3**2 + h4**2)
+ m2 = h3**4 / h4**3 / Tsun.value
+
+ m2_unc = m2.s * u.Msun if m2.s > 0 else None
+ sini_unc = sini.s if sini.s > 0 else None
+
+ return m2.n * u.Msun, sini.n, m2_unc, sini_unc
+
+
+def _SINI_to_SHAPMAX(model):
+ """Convert from standard SINI to alternate SHAPMAX parameterization
+
+ Also propagates uncertainties if present
+
+ Paramters
+ ---------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ SHAPMAX : astropy.units.Quantity
+ SHAPMAX_unc : astropy.units.Quantity or None
+ Uncertainty on SHAPMAX
+ """
+ if not hasattr(model, "SINI"):
+ raise AttributeError("Model must contain SINI for conversion to SHAPMAX")
+ sini = model.SINI.as_ufloat()
+ shapmax = -umath.log(1 - sini)
+ return shapmax.n, shapmax.s if shapmax.s > 0 else None
+
+
+def _SHAPMAX_to_SINI(model):
+ """Convert from alternate SHAPMAX to SINI parameterization
+
+ Also propagates uncertainties if present
+
+ Paramters
+ ---------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ SINI : astropy.units.Quantity
+ SINI_unc : astropy.units.Quantity or None
+ Uncertainty on SINI
+ """
+ if not hasattr(model, "SHAPMAX"):
+ raise AttributeError("Model must contain SHAPMAX for conversion to SINI")
+ shapmax = model.SHAPMAX.as_ufloat()
+ sini = 1 - umath.exp(-shapmax)
+ return sini.n, sini.s if sini.s > 0 else None
+
+
+def _from_ELL1(model):
+ """Convert from ELL1 parameterization to standard orbital parameterization
+
+ Converts using Eqns. 1, 2, and 3 from Lange et al. (2001)
+ Also computes EDOT if present
+ Also propagates uncertainties if present
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ ECC : astropy.units.Quantity
+ OM : astropy.units.Quantity
+ T0 : astropy.units.Quantity
+ EDOT : astropy.units.Quantity or None
+ OMDOT : astropy.units.Quantity or None
+ ECC_unc : astropy.units.Quantity or None
+ Uncertainty on ECC
+ OM_unc : astropy.units.Quantity or None
+ Uncertainty on OM
+ T0_unc : astropy.units.Quantity or None
+ Uncertainty on T0
+ EDOT_unc : astropy.units.Quantity or None
+ Uncertainty on EDOT
+ OMDOTDOT_unc : astropy.units.Quantity or None
+ Uncertainty on OMDOT
+
+ References
+ ----------
+ - Lange et al. (2001), MNRAS, 326, 274 [1]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2001MNRAS.326..274L/abstract
+
+ """
+ if model.BINARY.value not in ["ELL1", "ELL1H", "ELL1k"]:
+ raise ValueError(f"Requires model ELL1* rather than {model.BINARY.value}")
+
+ PB, PBerr = model.pb()
+ pb = ufloat(PB.to_value(u.d), PBerr.to_value(u.d) if PBerr is not None else 0)
+ eps1 = model.EPS1.as_ufloat()
+ eps2 = model.EPS2.as_ufloat()
+ om = umath.atan2(eps1, eps2)
+ if om < 0:
+ om += 2 * np.pi
+ ecc = umath.sqrt(eps1**2 + eps2**2)
+
+ tasc1, tasc2 = model.TASC.as_ufloats()
+ t01 = tasc1
+ t02 = tasc2 + (pb / 2 / np.pi) * om
+ T0 = Time(
+ t01.n,
+ val2=t02.n,
+ scale=model.TASC.quantity.scale,
+ precision=model.TASC.quantity.precision,
+ format="jd",
+ )
+ edot = None
+ omdot = None
+ if model.BINARY.value == "ELL1k":
+ lnedot = model.LNEDOT.as_ufloat(u.Hz)
+ edot = lnedot * ecc
+ omdot = model.OMDOT.as_ufloat(u.rad / u.s)
+
+ else:
+ if model.EPS1DOT.quantity is not None and model.EPS2DOT.quantity is not None:
+ eps1dot = model.EPS1DOT.as_ufloat(u.Hz)
+ eps2dot = model.EPS2DOT.as_ufloat(u.Hz)
+ edot = (eps1dot * eps1 + eps2dot * eps2) / ecc
+ omdot = (eps1dot * eps2 - eps2dot * eps1) / ecc**2
+
+ return (
+ ecc.n,
+ (om.n * u.rad).to(u.deg),
+ T0,
+ edot.n * u.Hz if edot is not None else None,
+ (omdot.n * u.rad / u.s).to(u.deg / u.yr) if omdot is not None else None,
+ ecc.s if ecc.s > 0 else None,
+ (om.s * u.rad).to(u.deg) if om.s > 0 else None,
+ t02.s * u.d if t02.s > 0 else None,
+ edot.s * u.Hz if (edot is not None and edot.s > 0) else None,
+ (omdot.s * u.rad / u.s).to(u.deg / u.yr)
+ if (omdot is not None and omdot.s > 0)
+ else None,
+ )
+
+
+def _to_ELL1(model):
+ """Convert from standard orbital parameterization to ELL1 parameterization
+
+ Converts using Eqns. 1, 2, and 3 from Lange et al. (2001)
+ Also computes EPS?DOT if present
+ Also propagates uncertainties if present
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ EPS1 : astropy.units.Quantity
+ EPS2 : astropy.units.Quantity
+ TASC : astropy.units.Quantity
+ EPS1DOT : astropy.units.Quantity or None
+ EPS2DOT : astropy.units.Quantity or None
+ EPS1_unc : astropy.units.Quantity or None
+ Uncertainty on EPS1
+ EPS2_unc : astropy.units.Quantity or None
+ Uncertainty on EPS2
+ TASC_unc : astropy.units.Quantity or None
+ Uncertainty on TASC
+ EPS1DOT_unc : astropy.units.Quantity or None
+ Uncertainty on EPS1DOT
+ EPS2DOT_unc : astropy.units.Quantity or None
+ Uncertainty on EPS2DOT
+
+ References
+ ----------
+ - Lange et al. (2001), MNRAS, 326, 274 [1]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2001MNRAS.326..274L/abstract
+
+ """
+ if not (hasattr(model, "ECC") and hasattr(model, "T0") and hasattr(model, "OM")):
+ raise AttributeError(
+ "Model must contain ECC, T0, OM for conversion to EPS1/EPS2"
+ )
+ ecc = model.ECC.as_ufloat()
+ om = model.OM.as_ufloat(u.rad)
+ eps1 = ecc * umath.sin(om)
+ eps2 = ecc * umath.cos(om)
+ PB, PBerr = model.pb()
+ pb = ufloat(PB.to_value(u.d), PBerr.to_value(u.d) if PBerr is not None else 0)
+ t01, t02 = model.T0.as_ufloats()
+ tasc1 = t01
+ tasc2 = t02 - (pb * om / 2 / np.pi)
+ TASC = Time(
+ tasc1.n,
+ val2=tasc2.n,
+ format="jd",
+ scale=model.T0.quantity.scale,
+ precision=model.T0.quantity.precision,
+ )
+ eps1dot = None
+ eps2dot = None
+ if model.EDOT.quantity is not None or model.OMDOT.quantity is not None:
+ if model.EDOT.quantity is not None:
+ edot = model.EDOT.as_ufloat(u.Hz)
+ else:
+ edot = ufloat(0, 0)
+ if model.OMDOT.quantity is not None:
+ omdot = model.OMDOT.as_ufloat(u.rad * u.Hz)
+ else:
+ omdot = ufloat(0, 0)
+ eps1dot = edot * umath.sin(om) + ecc * umath.cos(om) * omdot
+ eps2dot = edot * umath.cos(om) - ecc * umath.sin(om) * omdot
+ return (
+ eps1.n,
+ eps2.n,
+ TASC,
+ eps1dot.n * u.Hz,
+ eps2dot.n * u.Hz,
+ eps1.s if eps1.s > 0 else None,
+ eps2.s if eps2.s > 0 else None,
+ tasc2.s * u.d if tasc2.s > 0 else None,
+ eps1dot.s * u.Hz if (eps1dot is not None and eps1dot.s > 0) else None,
+ eps2dot.s * u.Hz if (eps2dot is not None and eps2dot.s > 0) else None,
+ )
+
+
+def _ELL1_to_ELL1k(model):
+ """Convert from ELL1 EPS1DOT/EPS2DOT to ELL1k LNEDOT/OMDOT
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ LNEDOT: astropy.units.Quantity
+ OMDOT: astropy.units.Quantity
+ LNEDOT_unc: astropy.units.Quantity or None
+ Uncertainty on LNEDOT
+ OMDOT_unc: astropy.units.Quantity or None
+ Uncertainty on OMDOT
+
+ References
+ ----------
+ - Susobhanan et al. (2018), MNRAS, 480 (4), 5260-5271 [1]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2018MNRAS.480.5260S/abstract
+ """
+ if model.BINARY.value not in ["ELL1", "ELL1H"]:
+ raise ValueError(f"Requires model ELL1/ELL1H rather than {model.BINARY.value}")
+ eps1 = model.EPS1.as_ufloat()
+ eps2 = model.EPS2.as_ufloat()
+ eps1dot = model.EPS1DOT.as_ufloat(u.Hz)
+ eps2dot = model.EPS2DOT.as_ufloat(u.Hz)
+ ecc = umath.sqrt(eps1**2 + eps2**2)
+ lnedot = (eps1 * eps1dot + eps2 * eps2dot) / ecc
+ omdot = (eps2 * eps1dot - eps1 * eps2dot) / ecc
+
+ with u.set_enabled_equivalencies(u.dimensionless_angles()):
+ lnedot_unc = lnedot.s / u.s if lnedot.s > 0 else None
+ omdot_unc = (omdot.s / u.s).to(u.deg / u.yr) if omdot.s > 0 else None
+ return lnedot.n / u.s, (omdot.n / u.s).to(u.deg / u.yr), lnedot_unc, omdot_unc
+
+
+def _ELL1k_to_ELL1(model):
+ """Convert from ELL1k LNEDOT/OMDOT to ELL1 EPS1DOT/EPS2DOT
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ EPS1DOT: astropy.units.Quantity
+ EPS2DOT: astropy.units.Quantity
+ EPS1DOT_unc: astropy.units.Quantity or None
+ Uncertainty on EPS1DOT
+ EPS2DOT_unc: astropy.units.Quantity or None
+ Uncertainty on EPS2DOT
+
+ References
+ ----------
+ - Susobhanan et al. (2018), MNRAS, 480 (4), 5260-5271 [1]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2018MNRAS.480.5260S/abstract
+ """
+ if model.BINARY.value != "ELL1k":
+ raise ValueError(f"Requires model ELL1k rather than {model.BINARY.value}")
+ eps1 = model.EPS1.as_ufloat()
+ eps2 = model.EPS2.as_ufloat()
+ lnedot = model.LNEDOT.as_ufloat(u.Hz)
+ with u.set_enabled_equivalencies(u.dimensionless_angles()):
+ omdot = model.OMDOT.as_ufloat(1 / u.s)
+ eps1dot = lnedot * eps1 + omdot * eps2
+ eps2dot = lnedot * eps2 - omdot * eps1
+
+ eps1dot_unc = eps1dot.s / u.s if eps1dot.s > 0 else None
+ eps2dot_unc = eps2dot.s / u.s if eps2dot.s > 0 else None
+ return eps1dot.n / u.s, eps2dot.n / u.s, eps1dot_unc, eps2dot_unc
+
+
+def _DDGR_to_PK(model):
+ """Convert DDGR model to equivalent PK parameters
+
+ Uses ``uncertainties`` module to propagate uncertainties
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+
+ Returns
+ -------
+ pbdot : uncertainties.core.Variable
+ gamma : uncertainties.core.Variable
+ omegadot : uncertainties.core.Variable
+ s : uncertainties.core.Variable
+ r : uncertainties.core.Variable
+ Dr : uncertainties.core.Variable
+ Dth : uncertainties.core.Variable
+ """
+ if model.BINARY.value != "DDGR":
+ raise ValueError(
+ f"Requires DDGR model for conversion, not '{model.BINARY.value}'"
+ )
+ tsun = Tsun.to_value(u.s)
+ mtot = model.MTOT.as_ufloat(u.Msun)
+ mc = model.M2.as_ufloat(u.Msun)
+ x = model.A1.as_ufloat()
+ PB, PBerr = model.pb()
+ pb = ufloat(PB.to_value(u.s), PBerr.to_value(u.s) if PBerr is not None else 0)
+ n = 2 * np.pi / pb
+ mp = mtot - mc
+ ecc = model.ECC.as_ufloat()
+ # units are seconds
+ gamma = (
+ tsun ** (2.0 / 3)
+ * n ** (-1.0 / 3)
+ * ecc
+ * (mc * (mp + 2 * mc) / (mp + mc) ** (4.0 / 3))
+ )
+ # units as seconds
+ r = tsun * mc
+ # units are radian/s
+ omegadot = (
+ (3 * tsun ** (2.0 / 3))
+ * n ** (5.0 / 3)
+ * (1 / (1 - ecc**2))
+ * (mp + mc) ** (2.0 / 3)
+ )
+ if model.XOMDOT.quantity is not None:
+ omegadot += model.XOMDOT.as_ufloat(u.rad / u.s)
+ fe = (1 + (73.0 / 24) * ecc**2 + (37.0 / 96) * ecc**4) / (1 - ecc**2) ** (
+ 7.0 / 2
+ )
+ # units as s/s
+ pbdot = (
+ (-192 * np.pi / 5)
+ * tsun ** (5.0 / 3)
+ * n ** (5.0 / 3)
+ * fe
+ * (mp * mc)
+ / (mp + mc) ** (1.0 / 3)
+ )
+ if model.XPBDOT.quantity is not None:
+ pbdot += model.XPBDOT.as_ufloat(u.s / u.s)
+ # dimensionless
+ s = tsun ** (-1.0 / 3) * n ** (2.0 / 3) * x * (mp + mc) ** (2.0 / 3) / mc
+ Dr = (
+ tsun ** (2.0 / 3)
+ * n ** (2.0 / 3)
+ * (3 * mp**2 + 6 * mp * mc + 2 * mc**2)
+ / (mp + mc) ** (4.0 / 3)
+ )
+ Dth = (
+ tsun ** (2.0 / 3)
+ * n ** (2.0 / 3)
+ * (3.5 * mp**2 + 6 * mp * mc + 2 * mc**2)
+ / (mp + mc) ** (4.0 / 3)
+ )
+ return pbdot, gamma, omegadot, s, r, Dr, Dth
+
+
+def _transfer_params(inmodel, outmodel, badlist=[]):
+ """Transfer parameters between an input and output model, excluding certain parameters
+
+ Parameters (input or output) that are :class:`~pint.models.parameter.funcParameter` are not copied
+
+ Parameters
+ ----------
+ inmodel : pint.models.timing_model.TimingModel
+ outmodel : pint.models.timing_model.TimingModel
+ badlist : list, optional
+ List of parameters to not transfer
+
+ """
+ inbinary_component_name = [
+ x for x in inmodel.components.keys() if x.startswith("Binary")
+ ][0]
+ outbinary_component_name = [
+ x for x in outmodel.components.keys() if x.startswith("Binary")
+ ][0]
+ for p in inmodel.components[inbinary_component_name].params:
+ if p not in badlist:
+ setattr(
+ outmodel.components[outbinary_component_name],
+ p,
+ copy.deepcopy(getattr(inmodel.components[inbinary_component_name], p)),
+ )
+
+
+def convert_binary(model, output, NHARMS=3, useSTIGMA=False, KOM=0 * u.deg):
+ """
+ Convert between binary models
+
+ Input models can be from :class:`~pint.models.binary_dd.BinaryDD`, :class:`~pint.models.binary_dd.BinaryDDS`,
+ :class:`~pint.models.binary_dd.BinaryDDGR`, :class:`~pint.models.binary_bt.BinaryBT`, :class:`~pint.models.binary_ddk.BinaryDDK`,
+ :class:`~pint.models.binary_ell1.BinaryELL1`, :class:`~pint.models.binary_ell1.BinaryELL1H`, :class:`~pint.models.binary_ell1.BinaryELL1k`,
+
+ Output models can be from :class:`~pint.models.binary_dd.BinaryDD`, :class:`~pint.models.binary_dd.BinaryDDS`,
+ :class:`~pint.models.binary_bt.BinaryBT`, :class:`~pint.models.binary_ddk.BinaryDDK`, :class:`~pint.models.binary_ell1.BinaryELL1`,
+ :class:`~pint.models.binary_ell1.BinaryELL1H`, :class:`~pint.models.binary_ell1.BinaryELL1k`
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+ output : str
+ Output model type
+ NHARMS : int, optional
+ Number of harmonics (``ELL1H`` only)
+ useSTIGMA : bool, optional
+ Whether to use STIGMA or H4 (``ELL1H`` only)
+ KOM : astropy.units.Quantity
+ Longitude of the ascending node (``DDK`` only)
+
+ Returns
+ -------
+ outmodel : pint.models.timing_model.TimingModel
+ """
+ # Do initial checks
+ if output not in binary_types:
+ raise ValueError(
+ f"Requested output binary '{output}' is not one of the known types ({binary_types})"
+ )
+
+ if not model.is_binary:
+ raise AttributeError("Input model is not a binary")
+
+ binary_component_name = [
+ x for x in model.components.keys() if x.startswith("Binary")
+ ][0]
+ binary_component = model.components[binary_component_name]
+ if binary_component.binary_model_name == output:
+ log.debug(
+ f"Input model and requested output are both of type '{output}'; returning copy"
+ )
+ return copy.deepcopy(model)
+ log.debug(f"Converting from '{binary_component.binary_model_name}' to '{output}'")
+
+ if binary_component.binary_model_name in ["ELL1", "ELL1H", "ELL1k"]:
+ # from ELL1, ELL1H, ELL1k
+ if output == "ELL1H":
+ # ELL1,ELL1k -> ELL1H
+ stigma, h3, h4, stigma_unc, h3_unc, h4_unc = _M2SINI_to_orthometric(model)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = ["M2", "SINI", "BINARY", "EDOT", "OMDOT"]
+ outmodel.add_component(BinaryELL1H(), validate=False)
+ if binary_component.binary_model_name == "ELL1k":
+ badlist += ["LNEDOT"]
+ EPS1DOT, EPS2DOT, EPS1DOT_unc, EPS2DOT_unc = _ELL1k_to_ELL1(model)
+ if EPS1DOT is not None:
+ outmodel.EPS1DOT.quantity = EPS1DOT
+ if EPS1DOT_unc is not None:
+ outmodel.EPS1DOT.uncertainty = EPS1DOT_unc
+ if EPS2DOT is not None:
+ outmodel.EPS2DOT.quantity = EPS2DOT
+ if EPS2DOT_unc is not None:
+ outmodel.EPS2DOT.uncertainty = EPS2DOT_unc
+ outmodel.EPS1DOT.frozen = model.LNEDOT.frozen or model.OMDOT.frozen
+ outmodel.EPS2DOT.frozen = model.LNEDOT.frozen or model.OMDOT.frozen
+ _transfer_params(model, outmodel, badlist)
+ outmodel.NHARMS.value = NHARMS
+ outmodel.H3.quantity = h3
+ outmodel.H3.uncertainty = h3_unc
+ outmodel.H3.frozen = model.M2.frozen or model.SINI.frozen
+ if useSTIGMA:
+ # use STIGMA and H3
+ outmodel.STIGMA.quantity = stigma
+ outmodel.STIGMA.uncertainty = stigma_unc
+ outmodel.STIGMA.frozen = outmodel.H3.frozen
+ else:
+ # use H4 and H3
+ outmodel.H4.quantity = h4
+ outmodel.H4.uncertainty = h4_unc
+ outmodel.H4.frozen = outmodel.H3.frozen
+ elif output in ["ELL1"]:
+ if model.BINARY.value == "ELL1H":
+ # ELL1H -> ELL1
+ M2, SINI, M2_unc, SINI_unc = _orthometric_to_M2SINI(model)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = ["H3", "H4", "STIGMA", "BINARY", "EDOT", "OMDOT"]
+ if output == "ELL1":
+ outmodel.add_component(BinaryELL1(), validate=False)
+ _transfer_params(model, outmodel, badlist)
+ outmodel.M2.quantity = M2
+ outmodel.SINI.quantity = SINI
+ if model.STIGMA.quantity is not None:
+ outmodel.M2.frozen = model.STIGMA.frozen or model.H3.frozen
+ outmodel.SINI.frozen = model.STIGMA.frozen
+ else:
+ outmodel.M2.frozen = model.STIGMA.frozen or model.H3.frozen
+ outmodel.SINI.frozen = model.STIGMA.frozen or model.H3.frozen
+ if M2_unc is not None:
+ outmodel.M2.uncertainty = M2_unc
+ if SINI_unc is not None:
+ outmodel.SINI.uncertainty = SINI_unc
+ elif model.BINARY.value == "ELL1k":
+ # ELL1k -> ELL1
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = ["BINARY", "LNEDOT", "OMDOT", "EDOT"]
+ if output == "ELL1":
+ outmodel.add_component(BinaryELL1(), validate=False)
+ EPS1DOT, EPS2DOT, EPS1DOT_unc, EPS2DOT_unc = _ELL1k_to_ELL1(model)
+ _transfer_params(model, outmodel, badlist)
+ if EPS1DOT is not None:
+ outmodel.EPS1DOT.quantity = EPS1DOT
+ if EPS1DOT_unc is not None:
+ outmodel.EPS1DOT.uncertainty = EPS1DOT_unc
+ if EPS2DOT is not None:
+ outmodel.EPS2DOT.quantity = EPS2DOT
+ if EPS2DOT_unc is not None:
+ outmodel.EPS2DOT.uncertainty = EPS2DOT_unc
+ outmodel.EPS1DOT.frozen = model.LNEDOT.frozen or model.OMDOT.frozen
+ outmodel.EPS2DOT.frozen = model.LNEDOT.frozen or model.OMDOT.frozen
+ elif output == "ELL1k":
+ if model.BINARY.value in ["ELL1"]:
+ # ELL1 -> ELL1k
+ LNEDOT, OMDOT, LNEDOT_unc, OMDOT_unc = _ELL1_to_ELL1k(model)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = ["BINARY", "EPS1DOT", "EPS2DOT", "OMDOT", "EDOT"]
+ outmodel.add_component(BinaryELL1k(), validate=False)
+ _transfer_params(model, outmodel, badlist)
+ outmodel.LNEDOT.quantity = LNEDOT
+ outmodel.OMDOT.quantity = OMDOT
+ if LNEDOT_unc is not None:
+ outmodel.LNEDOT.uncertainty = LNEDOT_unc
+ if OMDOT_unc is not None:
+ outmodel.OMDOT.uncertainty = OMDOT_unc
+ outmodel.LNEDOT.frozen = model.EPS1DOT.frozen or model.EPS2DOT.frozen
+ outmodel.OMDOT.frozen = model.EPS1DOT.frozen or model.EPS2DOT.frozen
+ elif model.BINARY.value == "ELL1H":
+ # ELL1H -> ELL1k
+ LNEDOT, OMDOT, LNEDOT_unc, OMDOT_unc = _ELL1_to_ELL1k(model)
+ M2, SINI, M2_unc, SINI_unc = _orthometric_to_M2SINI(model)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = [
+ "BINARY",
+ "EPS1DOT",
+ "EPS2DOT",
+ "H3",
+ "H4",
+ "STIGMA",
+ "OMDOT",
+ "EDOT",
+ ]
+ outmodel.add_component(BinaryELL1k(), validate=False)
+ _transfer_params(model, outmodel, badlist)
+ outmodel.LNEDOT.quantity = LNEDOT
+ outmodel.OMDOT.quantity = OMDOT
+ if LNEDOT_unc is not None:
+ outmodel.LNEDOT.uncertainty = LNEDOT_unc
+ if OMDOT_unc is not None:
+ outmodel.OMDOT.uncertainty = OMDOT_unc
+ outmodel.LNEDOT.frozen = model.EPS1DOT.frozen or model.EPS2DOT.frozen
+ outmodel.OMDOT.frozen = model.EPS1DOT.frozen or model.EPS2DOT.frozen
+ outmodel.M2.quantity = M2
+ outmodel.SINI.quantity = SINI
+ if model.STIGMA.quantity is not None:
+ outmodel.M2.frozen = model.STIGMA.frozen or model.H3.frozen
+ outmodel.SINI.frozen = model.STIGMA.frozen
+ else:
+ outmodel.M2.frozen = model.STIGMA.frozen or model.H3.frozen
+ outmodel.SINI.frozen = model.STIGMA.frozen or model.H3.frozen
+ if M2_unc is not None:
+ outmodel.M2.uncertainty = M2_unc
+ if SINI_unc is not None:
+ outmodel.SINI.uncertainty = SINI_unc
+ elif output in ["DD", "DDS", "DDK", "BT"]:
+ # (ELL1, ELL1k, ELL1H) -> (DD, DDS, DDK, BT)
+ # need to convert from EPS1/EPS2/TASC to ECC/OM/TASC
+ (
+ ECC,
+ OM,
+ T0,
+ EDOT,
+ OMDOT,
+ ECC_unc,
+ OM_unc,
+ T0_unc,
+ EDOT_unc,
+ OMDOT_unc,
+ ) = _from_ELL1(model)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = [
+ "ECC",
+ "OM",
+ "TASC",
+ "EPS1",
+ "EPS2",
+ "EPS1DOT",
+ "EPS2DOT",
+ "BINARY",
+ "OMDOT",
+ "EDOT",
+ ]
+ if output == "DD":
+ outmodel.add_component(BinaryDD(), validate=False)
+ elif output == "DDS":
+ outmodel.add_component(BinaryDDS(), validate=False)
+ badlist.append("SINI")
+ elif output == "DDK":
+ outmodel.add_component(BinaryDDK(), validate=False)
+ badlist.append("SINI")
+ elif output == "BT":
+ outmodel.add_component(BinaryBT(), validate=False)
+ badlist += ["M2", "SINI"]
+ if binary_component.binary_model_name == "ELL1H":
+ badlist += ["H3", "H4", "STIGMA", "VARSIGMA"]
+ _transfer_params(model, outmodel, badlist)
+ outmodel.ECC.quantity = ECC
+ outmodel.ECC.uncertainty = ECC_unc
+ outmodel.ECC.frozen = model.EPS1.frozen or model.EPS2.frozen
+ outmodel.OM.quantity = OM.to(u.deg, equivalencies=u.dimensionless_angles())
+ outmodel.OM.uncertainty = OM_unc.to(
+ u.deg, equivalencies=u.dimensionless_angles()
+ )
+ outmodel.OM.frozen = model.EPS1.frozen or model.EPS2.frozen
+ outmodel.T0.quantity = T0
+ outmodel.T0.uncertainty = T0_unc
+ if model.PB.quantity is not None:
+ outmodel.T0.frozen = (
+ model.EPS1.frozen
+ or model.EPS2.frozen
+ or model.TASC.frozen
+ or model.PB.frozen
+ )
+ elif model.FB0.quantity is not None:
+ outmodel.T0.frozen = (
+ model.EPS1.frozen
+ or model.EPS2.frozen
+ or model.TASC.frozen
+ or model.FB0.frozen
+ )
+ if EDOT is not None:
+ outmodel.EDOT.quantity = EDOT
+ if EDOT_unc is not None:
+ outmodel.EDOT.uncertainty = EDOT_unc
+ if OMDOT is not None:
+ outmodel.OMDOT.quantity = OMDOT
+ if OMDOT_unc is not None:
+ outmodel.OMDOT.uncertainty = OMDOT_unc
+ if binary_component.binary_model_name != "ELL1k":
+ outmodel.EDOT.frozen = model.EPS1DOT.frozen or model.EPS2DOT.frozen
+ outmodel.OMDOT.frozen = model.EPS1DOT.frozen or model.EPS2DOT.frozen
+ else:
+ outmodel.EDOT.frozen = model.LNEDOT.frozen
+ if binary_component.binary_model_name == "ELL1H":
+ M2, SINI, M2_unc, SINI_unc = _orthometric_to_M2SINI(model)
+ outmodel.M2.quantity = M2
+ outmodel.SINI.quantity = SINI
+ if M2_unc is not None:
+ outmodel.M2.uncertainty = M2_unc
+ if SINI_unc is not None:
+ outmodel.SINI.uncertainty = SINI_unc
+ if model.STIGMA.quantity is not None:
+ outmodel.SINI.frozen = model.STIGMA.frozen
+ outmodel.M2.frozen = model.STIGMA.frozen or model.H3.frozen
+ else:
+ outmodel.SINI.frozen = model.H3.frozen or model.H4.frozen
+ outmodel.M2.frozen = model.H3.frozen or model.H4.frozen
+ else:
+ raise ValueError(
+ f"Do not know how to convert from {binary_component.binary_model_name} to {output}"
+ )
+ elif binary_component.binary_model_name in ["DD", "DDGR", "DDS", "DDK", "BT"]:
+ if output in ["DD", "DDS", "DDK", "BT"]:
+ # (DD, DDGR, DDS, DDK, BT) -> (DD, DDS, DDK, BT)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = [
+ "BINARY",
+ ]
+ if binary_component.binary_model_name == "DDS":
+ badlist += ["SHAPMAX", "SINI"]
+ elif binary_component.binary_model_name == "DDK":
+ badlist += ["KIN", "KOM"]
+ elif binary_component.binary_model_name == "DDGR":
+ badlist += [
+ "PBDOT",
+ "OMDOT",
+ "GAMMA",
+ "DR",
+ "DTH",
+ "SINI",
+ "XOMDOT",
+ "XPBDOT",
+ ]
+ if output == "DD":
+ outmodel.add_component(BinaryDD(), validate=False)
+ elif output == "DDS":
+ outmodel.add_component(BinaryDDS(), validate=False)
+ badlist.append("SINI")
+ elif output == "DDK":
+ outmodel.add_component(BinaryDDK(), validate=False)
+ badlist.append("SINI")
+ elif output == "BT":
+ outmodel.add_component(BinaryBT(), validate=False)
+ badlist += ["M2", "SINI"]
+ _transfer_params(model, outmodel, badlist)
+ if binary_component.binary_model_name == "DDS":
+ SINI, SINI_unc = _SHAPMAX_to_SINI(model)
+ outmodel.SINI.quantity = SINI
+ if SINI_unc is not None:
+ outmodel.SINI.uncertainty = SINI_unc
+ elif binary_component.binary_model_name == "DDK":
+ if model.KIN.quantity is not None:
+ outmodel.SINI.quantity = np.sin(model.KIN.quantity)
+ if model.KIN.uncertainty is not None:
+ outmodel.SINI.uncertainty = np.abs(
+ model.KIN.uncertainty * np.cos(model.KIN.quantity)
+ ).to(
+ u.dimensionless_unscaled,
+ equivalencies=u.dimensionless_angles(),
+ )
+ outmodel.SINI.frozen = model.KIN.frozen
+ elif binary_component.binary_model_name == "DDGR":
+ pbdot, gamma, omegadot, s, r, Dr, Dth = _DDGR_to_PK(model)
+ outmodel.GAMMA.value = gamma.n
+ if gamma.s > 0:
+ outmodel.GAMMA.uncertainty_value = gamma.s
+ outmodel.PBDOT.value = pbdot.n
+ if pbdot.s > 0:
+ outmodel.PBDOT.uncertainty_value = pbdot.s
+ outmodel.OMDOT.value = (omegadot.n * u.rad / u.s).to_value(u.deg / u.yr)
+ if omegadot.s > 0:
+ outmodel.OMDOT.uncertainty_value = (
+ omegadot.s * u.rad / u.s
+ ).to_value(u.deg / u.yr)
+ outmodel.GAMMA.frozen = model.PB.frozen or model.M2.frozen
+ outmodel.OMDOT.frozen = (
+ model.PB.frozen or model.M2.frozen or model.ECC.frozen
+ )
+ outmodel.PBDOT.frozen = (
+ model.PB.frozen or model.M2.frozen or model.ECC.frozen
+ )
+ if output != "BT":
+ outmodel.DR.value = Dr.n
+ if Dr.s > 0:
+ outmodel.DR.uncertainty_value = Dr.s
+ outmodel.DTH.value = Dth.n
+ if Dth.s > 0:
+ outmodel.DTH.uncertainty_value = Dth.s
+ outmodel.DR.frozen = model.PB.frozen or model.M2.frozen
+ outmodel.DTH.frozen = model.PB.frozen or model.M2.frozen
+
+ if output == "DDS":
+ shapmax = -umath.log(1 - s)
+ outmodel.SHAPMAX.value = shapmax.n
+ if shapmax.s > 0:
+ outmodel.SHAPMAX.uncertainty_value = shapmax.s
+ outmodel.SHAPMAX.frozen = (
+ model.PB.frozen
+ or model.M2.frozen
+ or model.ECC.frozen
+ or model.A1.frozen
+ )
+ elif output == "DDK":
+ kin = umath.asin(s)
+ outmodel.KIN.value = kin.n
+ if kin.s > 0:
+ outmodel.KIN.uncertainty_value = kin.s
+ outmodel.KIN.frozen = (
+ model.PB.frozen
+ or model.M2.frozen
+ or model.ECC.frozen
+ or model.A1.frozen
+ )
+ log.warning(
+ f"Setting KIN={outmodel.KIN}: check that the sign is correct"
+ )
+ else:
+ outmodel.SINI.value = s.n
+ if s.s > 0:
+ outmodel.SINI.uncertainty_value = s.s
+ outmodel.SINI.frozen = (
+ model.PB.frozen
+ or model.M2.frozen
+ or model.ECC.frozen
+ or model.A1.frozen
+ )
+
+ elif output in ["ELL1", "ELL1H", "ELL1k"]:
+ # (DD, DDGR, DDS, DDK, BT) -> (ELL1, ELL1H, ELL1k)
+ outmodel = copy.deepcopy(model)
+ outmodel.remove_component(binary_component_name)
+ outmodel.BINARY.value = output
+ # parameters not to copy
+ badlist = ["BINARY", "ECC", "OM", "T0", "OMDOT", "EDOT", "GAMMA"]
+ if binary_component.binary_model_name == "DDS":
+ badlist += ["SHAPMAX", "SINI"]
+ elif binary_component.binary_model_name == "DDK":
+ badlist += ["KIN", "KOM"]
+ if output == "ELL1":
+ outmodel.add_component(BinaryELL1(), validate=False)
+ elif output == "ELL1H":
+ outmodel.add_component(BinaryELL1H(), validate=False)
+ badlist += ["M2", "SINI"]
+ elif output == "ELL1k":
+ outmodel.add_component(BinaryELL1k(), validate=False)
+ badlist += ["EPS1DOT", "EPS2DOT"]
+ badlist.remove("OMDOT")
+ _transfer_params(model, outmodel, badlist)
+ (
+ EPS1,
+ EPS2,
+ TASC,
+ EPS1DOT,
+ EPS2DOT,
+ EPS1_unc,
+ EPS2_unc,
+ TASC_unc,
+ EPS1DOT_unc,
+ EPS2DOT_unc,
+ ) = _to_ELL1(model)
+ LNEDOT = None
+ LNEDOT_unc = None
+ if output == "ELL1k":
+ if model.EDOT.quantity is not None and model.ECC.quantity is not None:
+ LNEDOT = model.EDOT.quantity / model.ECC.quantity
+ if (
+ model.EDOT.uncertainty is not None
+ and model.ECC.uncertainty is not None
+ ):
+ LNEDOT_unc = np.sqrt(
+ (model.EDOT.uncertainty / model.ECC.quantity) ** 2
+ + (
+ model.EDOT.quantity
+ * model.ECC.uncertainty
+ / model.ECC.quantity**2
+ )
+ ** 2
+ )
+ outmodel.EPS1.quantity = EPS1
+ outmodel.EPS2.quantity = EPS2
+ outmodel.TASC.quantity = TASC
+ outmodel.EPS1.uncertainty = EPS1_unc
+ outmodel.EPS2.uncertainty = EPS2_unc
+ outmodel.TASC.uncertainty = TASC_unc
+ outmodel.EPS1.frozen = model.ECC.frozen or model.OM.frozen
+ outmodel.EPS2.frozen = model.ECC.frozen or model.OM.frozen
+ outmodel.TASC.frozen = (
+ model.ECC.frozen
+ or model.OM.frozen
+ or model.PB.frozen
+ or model.T0.frozen
+ )
+ if EPS1DOT is not None and output != "ELL1k":
+ outmodel.EPS1DOT.quantity = EPS1DOT
+ outmodel.EPS2DOT.quantity = EPS2DOT
+ outmodel.EPS1DOT.frozen = model.EDOT.frozen or model.OM.frozen
+ outmodel.EPS2DOT.frozen = model.EDOT.frozen or model.OM.frozen
+ if EPS1DOT_unc is not None:
+ outmodel.EPS1DOT.uncertainty = EPS1DOT_unc
+ outmodel.EPS2DOT.uncertainty = EPS2DOT_unc
+ if LNEDOT is not None and output == "ELL1k":
+ outmodel.LNEDOT.quantity = LNEDOT
+ outmodel.LNEDOT.frozen = model.EDOT.frozen
+ if LNEDOT_unc is not None:
+ outmodel.LNEDOT.uncertainty = LNEDOT_unc
+ if binary_component.binary_model_name == "DDS":
+ SINI, SINI_unc = _SHAPMAX_to_SINI(model)
+ outmodel.SINI.quantity = SINI
+ if SINI_unc is not None:
+ outmodel.SINI.uncertainty = SINI_unc
+ elif binary_component.binary_model_name == "DDK":
+ if model.KIN.quantity is not None:
+ outmodel.SINI.quantity = np.sin(model.KIN.quantity)
+ if model.KIN.uncertainty is not None:
+ outmodel.SINI.uncertainty = np.abs(
+ model.KIN.uncertainty * np.cos(model.KIN.quantity)
+ ).to(
+ u.dimensionless_unscaled,
+ equivalencies=u.dimensionless_angles(),
+ )
+ outmodel.SINI.frozen = model.KIN.frozen
+ if output == "ELL1H":
+ if binary_component.binary_model_name == "DDGR":
+ model = convert_binary(model, "DD")
+ stigma, h3, h4, stigma_unc, h3_unc, h4_unc = _M2SINI_to_orthometric(
+ model
+ )
+ outmodel.NHARMS.value = NHARMS
+ outmodel.H3.quantity = h3
+ outmodel.H3.uncertainty = h3_unc
+ outmodel.H3.frozen = model.M2.frozen or model.SINI.frozen
+ if useSTIGMA:
+ # use STIGMA and H3
+ outmodel.STIGMA.quantity = stigma
+ outmodel.STIGMA.uncertainty = stigma_unc
+ outmodel.STIGMA.frozen = outmodel.H3.frozen
+ else:
+ # use H4 and H3
+ outmodel.H4.quantity = h4
+ outmodel.H4.uncertainty = h4_unc
+ outmodel.H4.frozen = outmodel.H3.frozen
+
+ if output == "DDS" and binary_component.binary_model_name != "DDGR":
+ SHAPMAX, SHAPMAX_unc = _SINI_to_SHAPMAX(model)
+ outmodel.SHAPMAX.quantity = SHAPMAX
+ outmodel.SHAPMAX.uncertainty = SHAPMAX_unc
+ outmodel.SHAPMAX.frozen = model.SINI.frozen
+
+ if output == "DDK":
+ outmodel.KOM.quantity = KOM
+ if binary_component.binary_model_name != "DDGR":
+ if model.SINI.quantity is not None:
+ outmodel.KIN.quantity = np.arcsin(model.SINI.quantity).to(
+ u.deg, equivalencies=u.dimensionless_angles()
+ )
+ if model.SINI.uncertainty is not None:
+ outmodel.KIN.uncertainty = (
+ model.SINI.uncertainty / np.sqrt(1 - model.SINI.quantity**2)
+ ).to(u.deg, equivalencies=u.dimensionless_angles())
+ log.warning(
+ f"Setting KIN={outmodel.KIN} from SINI={model.SINI}: check that the sign is correct"
+ )
+ outmodel.KIN.frozen = model.SINI.frozen
+ outmodel.validate()
+
+ return outmodel
diff --git a/src/pint/config.py b/src/pint/config.py
index b0b941813..09754b852 100644
--- a/src/pint/config.py
+++ b/src/pint/config.py
@@ -32,7 +32,8 @@ def examplefile(filename):
Notes
-----
- This is **not** for files needed at runtime. Those are located by :func:`pint.config.runtimefile`. This is for files needed for the example notebooks.
+ This is **not** for files needed at runtime. Those are located by :func:`pint.config.runtimefile`.
+ This is for files needed for the example notebooks.
"""
return pkg_resources.resource_filename(
__name__, os.path.join("data/examples/", filename)
@@ -53,7 +54,8 @@ def runtimefile(filename):
Notes
-----
- This **is** for files needed at runtime. Files needed for the example notebooks are found via :func:`pint.config.examplefile`.
+ This **is** for files needed at runtime. Files needed for the example notebooks
+ are found via :func:`pint.config.examplefile`.
"""
return pkg_resources.resource_filename(
__name__, os.path.join("data/runtime/", filename)
diff --git a/src/pint/data/examples/J1028-5819-example.par b/src/pint/data/examples/J1028-5819-example.par
new file mode 100644
index 000000000..ef066f5c2
--- /dev/null
+++ b/src/pint/data/examples/J1028-5819-example.par
@@ -0,0 +1,25 @@
+# Created: 2023-05-17T14:02:21.738913
+# PINT_version: 0.9.5+145.ga1930cdf.dirty
+# User: Abhimanyu Susobhanan (abhimanyu)
+# Host: abhimanyu-VirtualBox
+# OS: Linux-5.19.0-41-generic-x86_64-with-glibc2.35
+# Format: pint
+# Converted from tempo2 example file "example1.par"
+PSRJ J1028-5819
+UNITS TDB
+DILATEFREQ N
+DMDATA N
+NTOA 0
+CHI2 0.0
+RAJ 10:28:28.00000000 1 0.00000000000000000000
+DECJ -58:19:05.20000000 1 0.00000000000000000000
+PMRA 0.0
+PMDEC 0.0
+PX 0.0
+POSEPOCH 54561.9998229616586976
+F0 10.940532469635118635 1 0.0
+F1 -1.9300000598500644258e-12 1 0.0
+PEPOCH 54561.9998229616586976
+PLANET_SHAPIRO N
+DM 96.525001496639228994
+DMEPOCH 54561.9998229616586976
diff --git a/src/pint/data/examples/J1028-5819-example.tim b/src/pint/data/examples/J1028-5819-example.tim
new file mode 100644
index 000000000..9e90e3ac0
--- /dev/null
+++ b/src/pint/data/examples/J1028-5819-example.tim
@@ -0,0 +1,67 @@
+FORMAT 1
+C Created: 2023-05-17T14:02:41.772745
+C PINT_version: 0.9.5+145.ga1930cdf.dirty
+C User: Abhimanyu Susobhanan (abhimanyu)
+C Host: abhimanyu-VirtualBox
+C OS: Linux-5.19.0-41-generic-x86_64-with-glibc2.35
+C Converted from tempo2 example file "example1.tim"
+fake.rf 1440.000000 53999.9949438210586574 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54033.9392599716457176 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54067.8140320096647454 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54101.7199188700690741 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54135.6176501855466436 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54169.5843236848327199 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54203.4800157137875231 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54237.3510975335328704 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54271.2991333712171065 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54305.1832830773148611 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54339.0968294235198727 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54372.9913715311699653 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54406.9205437255239236 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54440.8153243558259143 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54474.6859209558068866 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54508.5960892047673380 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54542.5266276072368750 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54576.4269947742438079 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54610.3821344692018750 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54644.2505184108899537 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54678.2069060193461690 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54712.0951898580254630 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54746.0111591634952662 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54779.8627460627891783 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54813.8238687929544328 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54847.7106791650223843 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54881.6404401986973610 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54915.5493542034817939 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54949.3823428037153819 800.000 parkes -format Tempo2
+fake.rf 1440.000000 54983.3545636637380440 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55017.2371734214493403 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55051.1171564791278241 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55085.0431951390837847 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55118.9767443743058682 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55152.8554486768006483 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55186.7720530367924305 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55220.6817774476756249 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55254.5490093732091434 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55288.5128635786631944 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55322.3801252292781250 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55356.3076750540482986 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55390.1945757767493171 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55424.1325480043075810 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55457.9987714308363310 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55491.9566212987769328 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55525.8580247742897801 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55559.7272517251385185 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55593.6251412496792130 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55627.6030797493924306 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55661.4369474826314815 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55695.3793900871958217 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55729.3232170576581482 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55763.2215356963027547 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55797.0894497030695602 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55830.9776975216997686 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55864.9299751509429282 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55898.8608666088026273 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55932.7717109066336459 800.000 parkes -format Tempo2
+fake.rf 1440.000000 55966.6424032251585881 800.000 parkes -format Tempo2
+fake.rf 1440.000000 56000.5824517227911342 800.000 parkes -format Tempo2
diff --git a/src/pint/data/runtime/observatories.json b/src/pint/data/runtime/observatories.json
index 1c2bb6489..2739dde26 100644
--- a/src/pint/data/runtime/observatories.json
+++ b/src/pint/data/runtime/observatories.json
@@ -256,7 +256,7 @@
"origin": "The Giant Metre-wave Radio Telescope.\nThe origin of this data is unknown but as of 2021 June 8 it agrees exactly with\nthe values used by TEMPO and TEMPO2.\nGMRT does not need clock files as the data is recorded against UTC(gps)."
},
"ort": {
- "aliases" : [
+ "aliases": [
"or"
],
"clock_fmt": "tempo2",
@@ -483,7 +483,7 @@
-4123529.78,
4147966.36
],
- "origin": "The Allan telescope array.\nOrigin of this data is unknown but as of 2021 June 8 this value agrees exactly with\nthe value used by TEMPO2.\n"
+ "origin": "The Allen telescope array.\nOrigin of this data is unknown but as of 2021 June 8 this value agrees exactly with\nthe value used by TEMPO2.\n"
},
"ccera": {
"itrf_xyz": [
@@ -1392,13 +1392,23 @@
],
"origin": "Imported from TEMPO obsys.dat 2021 June 8."
},
- "hess": {
- "include_bipm": false,
- "itrf_xyz": [
- 5622462.3793,
- 1665449.2317,
- -2505096.8054
- ],
- "origin": "H.E.S.S., the High Energy Stereoscopic System, an Imaging Atmospheric Cherenkov Telescope.\nThese coordinates are provided (from geodetic coordinates : lat = 23°16'18''S, lon = 16°30'00''E at 1800m asl) by the Maxime Regeard and Arache Djannati-Atai on behalf of the H.E.S.S. Collabotation, 2023 January 18."
- }
-}
+ "hess": {
+ "include_bipm": false,
+ "itrf_xyz": [
+ 5622462.3793,
+ 1665449.2317,
+ -2505096.8054
+ ],
+ "origin": "H.E.S.S., the High Energy Stereoscopic System, an Imaging Atmospheric Cherenkov Telescope.\nThese coordinates are provided (from geodetic coordinates : lat = 23°16'18''S, lon = 16°30'00''E at 1800m asl) by the Maxime Regeard and Arache Djannati-Atai on behalf of the H.E.S.S. Collabotation, 2023 January 18."
+ },
+ "hawc": {
+ "include_bipm": false,
+ "include_gps": true,
+ "itrf_xyz": [
+ -767864.42219771,
+ -5987810.65796205,
+ 2064148.57332404
+ ],
+ "origin": "HAWC: High Altitude Water Cherenkov Experiment.\nCoordinates provided (as geodetic coordinates : 18:59:41.63 N, 97:18:27.39 W, 4096 m) by the spokesperson Ke Fang on behalf of the HAWC Collabotation, 2023 April 21."
+ }
+}
\ No newline at end of file
diff --git a/src/pint/derived_quantities.py b/src/pint/derived_quantities.py
index 816f0b519..104b73d8d 100644
--- a/src/pint/derived_quantities.py
+++ b/src/pint/derived_quantities.py
@@ -10,6 +10,8 @@
__all__ = [
"a1sini",
"companion_mass",
+ "dr",
+ "dth",
"gamma",
"mass_funct",
"mass_funct2",
@@ -24,6 +26,7 @@
"pulsar_edot",
"pulsar_mass",
"shklovskii_factor",
+ "sini",
]
@@ -63,12 +66,12 @@ def p_to_f(p, pd, pdd=None):
fd = -pd / (p * p)
if pdd is None:
return [f, fd]
- else:
- if pdd == 0.0:
- fdd = 0.0 * f.unit / (u.s**2)
- else:
- fdd = 2.0 * pd * pd / (p**3.0) - pdd / (p * p)
- return [f, fd, fdd]
+ fdd = (
+ 0.0 * f.unit / (u.s**2)
+ if pdd == 0.0
+ else 2.0 * pd * pd / (p**3.0) - pdd / (p * p)
+ )
+ return [f, fd, fdd]
@u.quantity_input(
@@ -115,14 +118,13 @@ def pferrs(porf, porferr, pdorfd=None, pdorfderr=None):
"""
if pdorfd is None:
return [1.0 / porf, porferr / porf**2.0]
- else:
- forperr = porferr / porf**2.0
- fdorpderr = np.sqrt(
- (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0
- + pdorfderr**2.0 / porf**4.0
- )
- [forp, fdorpd] = p_to_f(porf, pdorfd)
- return [forp, forperr, fdorpd, fdorpderr]
+ forperr = porferr / porf**2.0
+ fdorpderr = np.sqrt(
+ (4.0 * pdorfd**2.0 * porferr**2.0) / porf**6.0
+ + pdorfderr**2.0 / porf**4.0
+ )
+ [forp, fdorpd] = p_to_f(porf, pdorfd)
+ return [forp, forperr, fdorpd, fdorpderr]
@u.quantity_input(fo=u.Hz)
@@ -291,8 +293,8 @@ def mass_funct(pb: u.d, x: u.cm):
----------
pb : astropy.units.Quantity
Binary period
- x : astropy.units.Quantity in ``pint.ls``
- Semi-major axis, A1SINI, in units of ls
+ x : astropy.units.Quantity
+ Semi-major axis, A1SINI, in units of ``pint.ls``
Returns
-------
@@ -712,6 +714,160 @@ def omdot(mp: u.Msun, mc: u.Msun, pb: u.d, e: u.dimensionless_unscaled):
return value.to(u.deg / u.yr, equivalencies=u.dimensionless_angles())
+@u.quantity_input
+def sini(mp: u.Msun, mc: u.Msun, pb: u.d, x: u.cm):
+ """Post-Keplerian sine of inclination, assuming general relativity.
+
+ Can handle scalar or array inputs.
+
+ Parameters
+ ----------
+ mp : astropy.units.Quantity
+ pulsar mass
+ mc : astropy.units.Quantity
+ companion mass
+ pb : astropy.units.Quantity
+ Binary orbital period
+ x : astropy.units.Quantity
+ Semi-major axis, A1SINI, in units of ``pint.ls``
+
+ Returns
+ -------
+ sini : astropy.units.Quantity
+
+ Raises
+ ------
+ astropy.units.UnitsError
+ If the input data are not appropriate quantities
+ TypeError
+ If the input data are not quantities
+
+ Notes
+ -----
+ Calculates
+
+ .. math::
+
+ s = T_{\odot}^{-1/3} \\left(\\frac{P_b}{2\pi}\\right)^{-2/3}
+ \\frac{(m_p+m_c)^{2/3}}{m_c}
+
+ with :math:`T_\odot = GM_\odot c^{-3}`.
+
+ More details in :ref:`Timing Models`. Also see [11]_.
+
+ .. [11] Lorimer & Kramer, 2008, "The Handbook of Pulsar Astronomy", Eqn. 8.51
+
+ """
+
+ return (
+ (const.G) ** (-1.0 / 3)
+ * (pb / 2 / np.pi) ** (-2.0 / 3)
+ * x
+ * (mp + mc) ** (2.0 / 3)
+ / mc
+ ).decompose()
+
+
+@u.quantity_input
+def dr(mp: u.Msun, mc: u.Msun, pb: u.d):
+ """Post-Keplerian Roemer delay term
+
+ dr (:math:`\delta_r`) is part of the relativistic deformation of the orbit
+
+ Parameters
+ ----------
+ mp : astropy.units.Quantity
+ pulsar mass
+ mc : astropy.units.Quantity
+ companion mass
+ pb : astropy.units.Quantity
+ Binary orbital period
+
+ Returns
+ -------
+ dr : astropy.units.Quantity
+
+ Raises
+ ------
+ astropy.units.UnitsError
+ If the input data are not appropriate quantities
+ TypeError
+ If the input data are not quantities
+
+ Notes
+ -----
+ Calculates
+
+ .. math::
+
+ \delta_r = T_{\odot}^{2/3} \\left(\\frac{P_b}{2\pi}\\right)^{2/3}
+ \\frac{3 m_p^2+6 m_p m_c +2m_c^2}{(m_p+m_c)^{4/3}}
+
+ with :math:`T_\odot = GM_\odot c^{-3}`.
+
+ More details in :ref:`Timing Models`. Also see [12]_.
+
+ .. [12] Lorimer & Kramer, 2008, "The Handbook of Pulsar Astronomy", Eqn. 8.54
+
+ """
+ return (
+ (const.G / const.c**3) ** (2.0 / 3)
+ * (2 * np.pi / pb) ** (2.0 / 3)
+ * (3 * mp**2 + 6 * mp * mc + 2 * mc**2)
+ / (mp + mc) ** (4 / 3)
+ ).decompose()
+
+
+@u.quantity_input
+def dth(mp: u.Msun, mc: u.Msun, pb: u.d):
+ """Post-Keplerian Roemer delay term
+
+ dth (:math:`\delta_{\\theta}`) is part of the relativistic deformation of the orbit
+
+ Parameters
+ ----------
+ mp : astropy.units.Quantity
+ pulsar mass
+ mc : astropy.units.Quantity
+ companion mass
+ pb : astropy.units.Quantity
+ Binary orbital period
+
+ Returns
+ -------
+ dth : astropy.units.Quantity
+
+ Raises
+ ------
+ astropy.units.UnitsError
+ If the input data are not appropriate quantities
+ TypeError
+ If the input data are not quantities
+
+ Notes
+ -----
+ Calculates
+
+ .. math::
+
+ \delta_{\\theta} = T_{\odot}^{2/3} \\left(\\frac{P_b}{2\pi}\\right)^{2/3}
+ \\frac{3.5 m_p^2+6 m_p m_c +2m_c^2}{(m_p+m_c)^{4/3}}
+
+ with :math:`T_\odot = GM_\odot c^{-3}`.
+
+ More details in :ref:`Timing Models`. Also see [13]_.
+
+ .. [13] Lorimer & Kramer, 2008, "The Handbook of Pulsar Astronomy", Eqn. 8.55
+
+ """
+ return (
+ (const.G / const.c**3) ** (2.0 / 3)
+ * (2 * np.pi / pb) ** (2.0 / 3)
+ * (3.5 * mp**2 + 6 * mp * mc + 2 * mc**2)
+ / (mp + mc) ** (4 / 3)
+ ).decompose()
+
+
@u.quantity_input
def omdot_to_mtot(omdot: u.deg / u.yr, pb: u.d, e: u.dimensionless_unscaled):
"""Determine total mass from Post-Keplerian longitude of periastron precession rate omdot,
diff --git a/src/pint/erfautils.py b/src/pint/erfautils.py
index 3e9070d05..9267b4670 100644
--- a/src/pint/erfautils.py
+++ b/src/pint/erfautils.py
@@ -61,17 +61,13 @@ def gcrs_posvel_from_itrf(loc, toas, obsname="obs"):
unpack = True
else:
ttoas = toas
+ elif np.isscalar(toas):
+ ttoas = Time([toas], format="mjd")
+ unpack = True
else:
- if np.isscalar(toas):
- ttoas = Time([toas], format="mjd")
- unpack = True
- else:
- ttoas = toas
+ ttoas = toas
t = ttoas
pos, vel = loc.get_gcrs_posvel(t)
r = PosVel(pos.xyz, vel.xyz, obj=obsname, origin="earth")
- if unpack:
- return r[0]
- else:
- return r
+ return r[0] if unpack else r
diff --git a/src/pint/event_toas.py b/src/pint/event_toas.py
index f7ddd5391..a2768d4c9 100644
--- a/src/pint/event_toas.py
+++ b/src/pint/event_toas.py
@@ -1,13 +1,55 @@
-"""Generic function to load TOAs from events files."""
+"""Generic functions to load TOAs from events files, along with specific implementations for different missions.
+The versions that look like ``get_..._TOAs()`` are preferred: the others are retained for backward compatibility.
+
+**Instrument-specific Functions**
+
+.. autofunction:: pint.event_toas.get_NuSTAR_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.get_NICER_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.get_RXTE_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.get_IXPE_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.get_Swift_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.get_XMM_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.load_NuSTAR_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.load_NICER_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.load_RXTE_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.load_IXPE_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.load_Swift_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+.. autofunction:: pint.event_toas.load_XMM_TOAs(eventname [, minmjd, maxmjd, errors, ephem, planets])
+
+"""
import os
+from functools import partial
import astropy.io.fits as pyfits
+from astropy.time import Time
+from astropy.coordinates import EarthLocation
+from astropy import units as u
import numpy as np
from loguru import logger as log
import pint.toa as toa
from pint.fits_utils import read_fits_event_mjds_tuples
+"""
+Default TOA (event) uncertainty depending on facility
+
+* RXTE: https://ui.adsabs.harvard.edu/abs/1998ApJ...501..749R/abstract
+* IXPE: https://ui.adsabs.harvard.edu/abs/2019SPIE11118E..0VO/abstract
+* XMM: https://ui.adsabs.harvard.edu/abs/2012A%26A...545A.126M/abstract
+* NuSTAR: https://ui.adsabs.harvard.edu/abs/2021ApJ...908..184B/abstract
+* Swift: https://ui.adsabs.harvard.edu/abs/2005SPIE.5898..377C/abstract
+* NICER: https://heasarc.gsfc.nasa.gov/docs/nicer/mission_guide/
+"""
+_default_uncertainty = {
+ "NICER": 0.1 * u.us,
+ "RXTE": 2.5 * u.us,
+ "IXPE": 20 * u.us,
+ "XMM": 48 * u.us,
+ "NuSTAR": 65 * u.us,
+ "Swift": 300 * u.us,
+ "default": 1 * u.us,
+}
+
__all__ = [
"load_fits_TOAs",
@@ -18,6 +60,14 @@
"load_RXTE_TOAs",
"load_Swift_TOAs",
"load_XMM_TOAs",
+ "get_fits_TOAs",
+ "get_event_TOAs",
+ "get_NuSTAR_TOAs",
+ "get_NICER_TOAs",
+ "get_IXPE_TOAs",
+ "get_RXTE_TOAs",
+ "get_Swift_TOAs",
+ "get_XMM_TOAs",
]
@@ -33,7 +83,7 @@ def read_mission_info_from_heasoft():
db = {}
with open(fname) as fobj:
- for line in fobj.readlines():
+ for line in fobj:
line = line.strip()
if line.startswith("!") or line == "":
@@ -100,8 +150,9 @@ def create_mission_config():
try:
mission_config["chandra"] = mission_config["axaf"]
except KeyError:
- log.warning("AXAF configuration not found -- likely HEADAS envariable not set.")
- pass
+ log.warning(
+ "AXAF configuration not found -- likely HEADAS env variable not set."
+ )
# Fix xte
mission_config["xte"]["fits_columns"] = {"ecol": "PHA"}
@@ -199,9 +250,84 @@ def load_fits_TOAs(
timeref=None,
minmjd=-np.inf,
maxmjd=np.inf,
+ errors=_default_uncertainty,
+):
+ """
+ Read photon event times out of a FITS file as a list of PINT :class:`~pint.toa.TOA` objects.
+
+ Correctly handles raw event files, or ones processed with axBary to have
+ barycentered TOAs. Different conditions may apply to different missions.
+
+ The minmjd/maxmjd parameters can be used to avoid instantiation of TOAs
+ we don't want, which can otherwise be very slow.
+
+ Parameters
+ ----------
+ eventname : str
+ File name of the FITS event list
+ mission : str
+ Name of the mission (e.g. RXTE, XMM)
+ weights : array or None
+ The array has to be of the same size as the event list. Overwrites
+ possible weight lists from mission-specific FITS files
+ extension : str
+ FITS extension to read
+ timesys : str, default None
+ Force this time system
+ timeref : str, default None
+ Forse this time reference
+ minmjd : float, default "-infinity"
+ minimum MJD timestamp to return
+ maxmjd : float, default "infinity"
+ maximum MJD timestamp to return
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
+
+ Returns
+ -------
+ toalist : list of :class:`~pint.toa.TOA` objects
+
+ Note
+ ----
+ This list should be converted into a :class:`~pint.toa.TOAs` object with :func:`pint.toa.get_TOAs_list` for most operations
+
+ See Also
+ --------
+ :func:`get_fits_TOAs`
+ """
+ toas = get_fits_TOAs(
+ eventname,
+ mission,
+ weights=weights,
+ extension=extension,
+ timesys=timesys,
+ timeref=timeref,
+ minmjd=minmjd,
+ maxmjd=maxmjd,
+ errors=errors,
+ )
+
+ return toas.to_TOA_list()
+
+
+def get_fits_TOAs(
+ eventname,
+ mission,
+ weights=None,
+ extension=None,
+ timesys=None,
+ timeref=None,
+ minmjd=-np.inf,
+ maxmjd=np.inf,
+ ephem=None,
+ planets=False,
+ include_bipm=False,
+ include_gps=False,
+ errors=_default_uncertainty["default"],
):
"""
- Read photon event times out of a FITS file as PINT TOA objects.
+ Read photon event times out of a FITS file as :class:`pint.toa.TOAs` object
Correctly handles raw event files, or ones processed with axBary to have
barycentered TOAs. Different conditions may apply to different missions.
@@ -228,10 +354,23 @@ def load_fits_TOAs(
minimum MJD timestamp to return
maxmjd : float, default "infinity"
maximum MJD timestamp to return
+ ephem : str, optional
+ The name of the solar system ephemeris to use; defaults to "DE421".
+ planets : bool, optional
+ Whether to apply Shapiro delays based on planet positions. Note that a
+ long-standing TEMPO2 bug in this feature went unnoticed for years.
+ Defaults to False.
+ include_bipm : bool, optional
+ Use TT(BIPM) instead of TT(TAI)
+ include_gps : bool, optional
+ Apply GPS to UTC clock corrections
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
Returns
-------
- toalist : list of TOA objects
+ pint.toa.TOAs
"""
# Load photon times from event file
hdulist = pyfits.open(eventname)
@@ -245,8 +384,7 @@ def load_fits_TOAs(
and hdulist[1].name not in extension.split(",")
):
raise RuntimeError(
- "First table in FITS file"
- + "must be {}. Found {}".format(extension, hdulist[1].name)
+ f"First table in FITS file must be {extension}. Found {hdulist[1].name}"
)
if isinstance(extension, int) and extension != 1:
raise ValueError(
@@ -257,11 +395,12 @@ def load_fits_TOAs(
timesys = _get_timesys(hdulist[1])
if timeref is None:
timeref = _get_timeref(hdulist[1])
+ log.info(f"TIMESYS: {timesys} TIMEREF: {timeref}")
check_timesys(timesys)
check_timeref(timeref)
if not mission_config[mission]["allow_local"] and timesys != "TDB":
- log.error("Raw spacecraft TOAs not yet supported for " + mission)
+ log.error(f"Raw spacecraft TOAs not yet supported for {mission}")
obs, scale = _default_obs_and_scale(mission, timesys, timeref)
@@ -277,6 +416,9 @@ def load_fits_TOAs(
if weights is not None:
new_kwargs["weights"] = weights
+ if not isinstance(errors, u.Quantity):
+ errors = errors * u.microsecond
+
# mask out times/columns outside of mjd range
mjds_float = np.asarray([r[0] + r[1] for r in mjds])
idx = (minmjd < mjds_float) & (mjds_float < maxmjd)
@@ -284,20 +426,45 @@ def load_fits_TOAs(
for key in new_kwargs.keys():
new_kwargs[key] = new_kwargs[key][idx]
- toalist = [None] * len(mjds)
- kw = {}
+ location = EarthLocation(0, 0, 0) if timeref == "GEOCENTRIC" else None
+
+ if len(mjds.shape) == 2:
+ t = Time(
+ val=mjds[:, 0],
+ val2=mjds[:, 1],
+ format="mjd",
+ scale=scale,
+ location=location,
+ )
+ else:
+ t = Time(mjds, format="mjd", scale=scale, location=location)
+ flags = [toa.FlagDict() for _ in range(len(mjds))]
for i in range(len(mjds)):
- # Create TOA list
for key in new_kwargs:
- kw[key] = str(new_kwargs[key][i])
- toalist[i] = toa.TOA(mjds[i], obs=obs, scale=scale, **kw)
-
- return toalist
+ flags[i][key] = str(new_kwargs[key][i])
+
+ return toa.get_TOAs_array(
+ t,
+ obs,
+ include_gps=include_gps,
+ include_bipm=include_bipm,
+ planets=planets,
+ ephem=ephem,
+ flags=flags,
+ errors=errors,
+ )
-def load_event_TOAs(eventname, mission, weights=None, minmjd=-np.inf, maxmjd=np.inf):
+def load_event_TOAs(
+ eventname,
+ mission,
+ weights=None,
+ minmjd=-np.inf,
+ maxmjd=np.inf,
+ errors=_default_uncertainty["default"],
+):
"""
- Read photon event times out of a FITS file as PINT TOA objects.
+ Read photon event times out of a FITS file as PINT :class:`~pint.toa.TOA` objects.
Correctly handles raw event files, or ones processed with axBary to have
barycentered TOAs. Different conditions may apply to different missions.
@@ -318,10 +485,21 @@ def load_event_TOAs(eventname, mission, weights=None, minmjd=-np.inf, maxmjd=np.
minimum MJD timestamp to return
maxmjd : float, default "infinity"
maximum MJD timestamp to return
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
Returns
-------
- toalist : list of TOA objects
+ toalist : list of :class:`~pint.toa.TOA` objects
+
+ Note
+ ----
+ This list should be converted into a :class:`~pint.toa.TOAs` object with :func:`pint.toa.get_TOAs_list` for most operations
+
+ See Also
+ --------
+ :func:`get_event_TOAs`
"""
# Load photon times from event file
@@ -337,28 +515,214 @@ def load_event_TOAs(eventname, mission, weights=None, minmjd=-np.inf, maxmjd=np.
extension=extension,
minmjd=minmjd,
maxmjd=maxmjd,
+ errors=errors,
)
-def load_RXTE_TOAs(eventname, minmjd=-np.inf, maxmjd=np.inf):
- return load_event_TOAs(eventname, "xte", minmjd=minmjd, maxmjd=maxmjd)
+def get_event_TOAs(
+ eventname,
+ mission,
+ weights=None,
+ minmjd=-np.inf,
+ maxmjd=np.inf,
+ ephem=None,
+ planets=False,
+ include_bipm=False,
+ include_gps=False,
+ errors=_default_uncertainty["default"],
+):
+ """
+ Read photon event times out of a FITS file as a :class:`pint.toa.TOAs` object
+ Correctly handles raw event files, or ones processed with axBary to have
+ barycentered TOAs. Different conditions may apply to different missions.
-def load_NICER_TOAs(eventname, minmjd=-np.inf, maxmjd=np.inf):
- return load_event_TOAs(eventname, "nicer", minmjd=minmjd, maxmjd=maxmjd)
+ The minmjd/maxmjd parameters can be used to avoid instantiation of TOAs
+ we don't want, which can otherwise be very slow.
+ Parameters
+ ----------
+ eventname : str
+ File name of the FITS event list
+ mission : str
+ Name of the mission (e.g. RXTE, XMM)
+ weights : array or None
+ The array has to be of the same size as the event list. Overwrites
+ possible weight lists from mission-specific FITS files
+ minmjd : float, default "-infinity"
+ minimum MJD timestamp to return
+ maxmjd : float, default "infinity"
+ maximum MJD timestamp to return
+ ephem : str, optional
+ The name of the solar system ephemeris to use; defaults to "DE421".
+ planets : bool, optional
+ Whether to apply Shapiro delays based on planet positions. Note that a
+ long-standing TEMPO2 bug in this feature went unnoticed for years.
+ Defaults to False.
+ include_bipm : bool, optional
+ Use TT(BIPM) instead of TT(TAI)
+ include_gps : bool, optional
+ Apply GPS to UTC clock corrections
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
-def load_IXPE_TOAs(eventname, minmjd=-np.inf, maxmjd=np.inf):
- return load_event_TOAs(eventname, "ixpe", minmjd=minmjd, maxmjd=maxmjd)
+ Returns
+ -------
+ pint.toa.TOAs
+ """
+ # Load photon times from event file
-def load_XMM_TOAs(eventname, minmjd=-np.inf, maxmjd=np.inf):
- return load_event_TOAs(eventname, "xmm", minmjd=minmjd, maxmjd=maxmjd)
+ try:
+ extension = mission_config[mission]["fits_extension"]
+ except ValueError:
+ log.warning("Mission name (TELESCOP) not recognized, using generic!")
+ extension = mission_config["generic"]["fits_extension"]
+ return get_fits_TOAs(
+ eventname,
+ mission,
+ weights=weights,
+ extension=extension,
+ minmjd=minmjd,
+ maxmjd=maxmjd,
+ ephem=ephem,
+ planets=planets,
+ include_bipm=include_bipm,
+ include_gps=include_gps,
+ errors=errors,
+ )
-def load_NuSTAR_TOAs(eventname, minmjd=-np.inf, maxmjd=np.inf):
- return load_event_TOAs(eventname, "nustar", minmjd=minmjd, maxmjd=maxmjd)
+# generic docstring for these functions
+_load_event_docstring = """
+ Read photon event times out of a {} file as PINT :class:`~pint.toa.TOA` objects.
+ Correctly handles raw event files, or ones processed with axBary to have
+ barycentered TOAs. Different conditions may apply to different missions.
+
+ The minmjd/maxmjd parameters can be used to avoid instantiation of TOAs
+ we don't want, which can otherwise be very slow.
+
+ Parameters
+ ----------
+ eventname : str
+ File name of the FITS event list
+ minmjd : float, default "-infinity"
+ minimum MJD timestamp to return
+ maxmjd : float, default "infinity"
+ maximum MJD timestamp to return
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
+
+ Returns
+ -------
+ toalist : list of :class:`~pint.toa.TOA` objects
+
+ Note
+ ----
+ This list should be converted into a :class:`~pint.toa.TOAs` object with :func:`pint.toa.get_TOAs_list` for most operations
+
+ See Also
+ --------
+ :func:`get_{}_TOAs`
+ :func:`load_event_TOAs`
+ """
+
+load_RXTE_TOAs = partial(
+ load_event_TOAs, mission="xte", errors=_default_uncertainty["RXTE"]
+)
+load_RXTE_TOAs.__doc__ = _load_event_docstring.format("RXTE", "RXTE")
+
+load_NICER_TOAs = partial(
+ load_event_TOAs, mission="nicer", errors=_default_uncertainty["NICER"]
+)
+load_NICER_TOAs.__doc__ = _load_event_docstring.format("NICER", "NICER")
+
+load_IXPE_TOAs = partial(
+ load_event_TOAs, mission="ixpe", errors=_default_uncertainty["IXPE"]
+)
+load_IXPE_TOAs.__doc__ = _load_event_docstring.format("IXPE", "IXPE")
+
+load_XMM_TOAs = partial(
+ load_event_TOAs, mission="xmm", errors=_default_uncertainty["XMM"]
+)
+load_XMM_TOAs.__doc__ = _load_event_docstring.format("XMM", "XMM")
+
+load_NuSTAR_TOAs = partial(
+ load_event_TOAs, mission="nustar", errors=_default_uncertainty["NuSTAR"]
+)
+load_NuSTAR_TOAs.__doc__ = _load_event_docstring.format("NuSTAR", "NuSTAR")
+
+load_Swift_TOAs = partial(
+ load_event_TOAs, mission="swift", errors=_default_uncertainty["Swift"]
+)
+load_Swift_TOAs.__doc__ = _load_event_docstring.format("Swift", "Swift")
+
+# generic docstring for these functions
+_get_event_docstring = """
+ Read photon event times out of a {} file as a :class:`pint.toa.TOAs` object
+
+ Correctly handles raw event files, or ones processed with axBary to have
+ barycentered TOAs. Different conditions may apply to different missions.
+
+ The minmjd/maxmjd parameters can be used to avoid instantiation of TOAs
+ we don't want, which can otherwise be very slow.
+
+ Parameters
+ ----------
+ eventname : str
+ File name of the FITS event list
+ minmjd : float, default "-infinity"
+ minimum MJD timestamp to return
+ maxmjd : float, default "infinity"
+ maximum MJD timestamp to return
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
+ ephem : str, optional
+ The name of the solar system ephemeris to use; defaults to "DE421".
+ planets : bool, optional
+ Whether to apply Shapiro delays based on planet positions. Note that a
+ long-standing TEMPO2 bug in this feature went unnoticed for years.
+ Defaults to False.
+
+ Returns
+ -------
+ pint.toa.TOAs
+
+ See Also
+ --------
+ :func:`get_event_TOAs`
+ """
-def load_Swift_TOAs(eventname, minmjd=-np.inf, maxmjd=np.inf):
- return load_event_TOAs(eventname, "swift", minmjd=minmjd, maxmjd=maxmjd)
+get_RXTE_TOAs = partial(
+ get_event_TOAs, mission="xte", errors=_default_uncertainty["RXTE"]
+)
+get_RXTE_TOAs.__doc__ = _get_event_docstring.format("RXTE")
+
+get_NICER_TOAs = partial(
+ get_event_TOAs, mission="nicer", errors=_default_uncertainty["NICER"]
+)
+get_NICER_TOAs.__doc__ = _get_event_docstring.format("NICER")
+
+get_IXPE_TOAs = partial(
+ get_event_TOAs, mission="ixpe", errors=_default_uncertainty["IXPE"]
+)
+get_IXPE_TOAs.__doc__ = _get_event_docstring.format("IXPE")
+
+get_XMM_TOAs = partial(
+ get_event_TOAs, mission="xmm", errors=_default_uncertainty["XMM"]
+)
+get_XMM_TOAs.__doc__ = _get_event_docstring.format("XMM")
+
+get_NuSTAR_TOAs = partial(
+ get_event_TOAs, mission="nustar", errors=_default_uncertainty["NuSTAR"]
+)
+get_NuSTAR_TOAs.__doc__ = _get_event_docstring.format("NuSTAR")
+
+get_Swift_TOAs = partial(
+ get_event_TOAs, mission="swift", errors=_default_uncertainty["Swift"]
+)
+get_Swift_TOAs.__doc__ = _get_event_docstring.format("Swift")
diff --git a/src/pint/eventstats.py b/src/pint/eventstats.py
index be4e03f6a..b24c39aae 100644
--- a/src/pint/eventstats.py
+++ b/src/pint/eventstats.py
@@ -39,15 +39,11 @@ def vec(func):
def to_array(x, dtype=None):
x = np.asarray(x, dtype=dtype)
- if len(x.shape) == 0:
- return np.asarray([x])
- return x
+ return np.asarray([x]) if len(x.shape) == 0 else x
def from_array(x):
- if (len(x.shape) == 1) and (x.shape[0] == 1):
- return x[0]
- return x
+ return x[0] if (len(x.shape) == 1) and (x.shape[0] == 1) else x
def sig2sigma(sig, two_tailed=True, logprob=False):
@@ -86,9 +82,8 @@ def sig2sigma(sig, two_tailed=True, logprob=False):
if logprob:
if np.any(logsig > 0):
raise ValueError("Probability must be between 0 and 1.")
- else:
- if np.any((sig > 1) | (sig <= 0)):
- raise ValueError("Probability must be between 0 and 1.")
+ elif np.any((sig > 1) | (sig <= 0)):
+ raise ValueError("Probability must be between 0 and 1.")
if not two_tailed:
sig *= 2
@@ -123,21 +118,16 @@ def sigma2sig(sigma, two_tailed=True):
"""
# this appears to handle up to machine precision with no problem
- if two_tailed:
- return erfc(sigma / 2**0.5)
- return 1 - 0.5 * erfc(-sigma / 2**0.5)
+ return erfc(sigma / 2**0.5) if two_tailed else 1 - 0.5 * erfc(-sigma / 2**0.5)
def sigma_trials(sigma, trials):
# correct a sigmal value for a trials factor
- if sigma < 20:
- p = sigma2sig(sigma) * trials
- if p >= 1:
- return 0
- return sig2sigma(p)
- else:
- # use an asymptotic expansion -- this needs to be checked!
+ # use an asymptotic expansion -- this needs to be checked!
+ if sigma >= 20:
return (sigma**2 - 2 * np.log(trials)) ** 0.5
+ p = sigma2sig(sigma) * trials
+ return 0 if p >= 1 else sig2sigma(p)
def z2m(phases, m=2):
@@ -312,11 +302,11 @@ def sf_hm(h, m=20, c=4, logprob=False):
# next, develop the integrals in the power series
alpha = 0.5 * exp(-0.5 * c)
- if not logprob:
- return exp(-0.5 * h) * (alpha ** arange(0, m) * ints).sum()
- else:
- # NB -- this has NOT been tested for partial underflow
- return -0.5 * h + np.log((alpha ** arange(0, m) * ints).sum())
+ return (
+ -0.5 * h + np.log((alpha ** arange(0, m) * ints).sum())
+ if logprob
+ else exp(-0.5 * h) * (alpha ** arange(0, m) * ints).sum()
+ )
def h2sig(h):
@@ -330,9 +320,7 @@ def sf_h20_dj1989(h):
formula of de Jager et al. 1989 -- NB the quadratic term is NOT correct."""
if h <= 23:
return 0.9999755 * np.exp(-0.39802 * h)
- if h > 50:
- return 4e-8
- return 1.210597 * np.exp(-0.45901 * h + 0.0022900 * h**2)
+ return 4e-8 if h > 50 else 1.210597 * np.exp(-0.45901 * h + 0.0022900 * h**2)
def sf_h20_dj2010(h):
@@ -352,8 +340,6 @@ def sf_stackedh(k, h, l=0.398405):
de Jager & Busching 2010."""
fact = lambda x: gamma(x + 1)
- p = 0
c = l * h
- for i in range(k):
- p += c**i / fact(i)
+ p = sum(c**i / fact(i) for i in range(k))
return p * np.exp(-c)
diff --git a/src/pint/fermi_toas.py b/src/pint/fermi_toas.py
index c374e57ac..00b2f8671 100644
--- a/src/pint/fermi_toas.py
+++ b/src/pint/fermi_toas.py
@@ -11,8 +11,10 @@
from pint.fits_utils import read_fits_event_mjds_tuples
from pint.observatory import get_observatory
+# default TOA (event) uncertainty depending on facility
+_default_uncertainty = 1 * u.us
-__all__ = ["load_Fermi_TOAs"]
+__all__ = ["load_Fermi_TOAs", "get_Fermi_TOAs"]
def calc_lat_weights(energies, angseps, logeref=4.1, logesig=0.5):
@@ -54,8 +56,7 @@ def calc_lat_weights(energies, angseps, logeref=4.1, logesig=0.5):
sigma = (
np.sqrt(
- psfpar0 * psfpar0 * np.power(100.0 / energies, 2.0 * psfpar1)
- + psfpar2 * psfpar2
+ (psfpar0**2 * np.power(100.0 / energies, 2.0 * psfpar1) + psfpar2**2)
)
/ scalepsf
)
@@ -77,13 +78,13 @@ def load_Fermi_TOAs(
minmjd=-np.inf,
maxmjd=np.inf,
fermiobs="Fermi",
+ errors=_default_uncertainty,
):
"""
- toalist = load_Fermi_TOAs(ft1name)
- Read photon event times out of a Fermi FT1 file and return
- a list of PINT TOA objects.
- Correctly handles raw FT1 files, or ones processed with gtbary
- to have barycentered or geocentered TOAs.
+ Read photon event times out of a Fermi FT1 file and return a list of PINT :class:`~pint.toa.TOA` objects.
+
+ Correctly handles raw FT1 files, or ones processed with gtbary
+ to have barycentered or geocentered TOAs.
Parameters
@@ -107,116 +108,38 @@ def load_Fermi_TOAs(
fermiobs: str
The default observatory name is Fermi, and must have already been
registered. The user can specify another name
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
Returns
-------
toalist : list
- A list of TOA objects corresponding to the Fermi events.
- """
-
- # Load photon times from FT1 file
- hdulist = fits.open(ft1name)
- ft1hdr = hdulist[1].header
- ft1dat = hdulist[1].data
-
- # TIMESYS will be 'TT' for unmodified Fermi LAT events (or geocentered), and
- # 'TDB' for events barycentered with gtbary
- # TIMEREF will be 'GEOCENTER' for geocentered events,
- # 'SOLARSYSTEM' for barycentered,
- # and 'LOCAL' for unmodified events
-
- timesys = ft1hdr["TIMESYS"]
- log.info("TIMESYS {0}".format(timesys))
- timeref = ft1hdr["TIMEREF"]
- log.info("TIMEREF {0}".format(timeref))
+ A list of :class:`~pint.toa.TOA` objects corresponding to the Fermi events.
- # Read time column from FITS file
- mjds = read_fits_event_mjds_tuples(hdulist[1])
- if len(mjds) == 0:
- log.error("No MJDs read from file!")
- raise
+ Note
+ ----
+ This list should be converted into a :class:`~pint.toa.TOAs` object with
+ :func:`pint.toa.get_TOAs_list` for most operations
- energies = ft1dat.field("ENERGY") * u.MeV
- if weightcolumn is not None:
- if weightcolumn == "CALC":
- photoncoords = SkyCoord(
- ft1dat.field("RA") * u.degree,
- ft1dat.field("DEC") * u.degree,
- frame="icrs",
- )
- weights = calc_lat_weights(
- ft1dat.field("ENERGY"),
- photoncoords.separation(targetcoord),
- logeref=logeref,
- logesig=logesig,
- )
- else:
- weights = ft1dat.field(weightcolumn)
- if minweight > 0.0:
- idx = np.where(weights > minweight)[0]
- mjds = mjds[idx]
- energies = energies[idx]
- weights = weights[idx]
+ See Also
+ --------
+ :func:`get_Fermi_TOAs`
- # limit the TOAs to ones in selected MJD range
- mjds_float = np.asarray([r[0] + r[1] for r in mjds])
- idx = (minmjd < mjds_float) & (mjds_float < maxmjd)
- mjds = mjds[idx]
- energies = energies[idx]
- if weightcolumn is not None:
- weights = weights[idx]
-
- if timesys == "TDB":
- log.info("Building barycentered TOAs")
- obs = "Barycenter"
- scale = "tdb"
- msg = "barycentric"
- elif (timesys == "TT") and (timeref == "LOCAL"):
- assert timesys == "TT"
- try:
- get_observatory(fermiobs)
- except KeyError:
- log.error(
- "%s observatory not defined. Make sure you have specified an FT2 file!"
- % fermiobs
- )
- raise
- obs = fermiobs
- scale = "tt"
- msg = "spacecraft local"
- elif (timesys == "TT") and (timeref == "GEOCENTRIC"):
- obs = "Geocenter"
- scale = "tt"
- msg = "geocentric"
- else:
- raise ValueError("Unrecognized TIMEREF/TIMESYS.")
-
- log.info(
- "Building {0} TOAs, with MJDs in range {1} to {2}".format(
- msg, mjds[0, 0] + mjds[0, 1], mjds[-1, 0] + mjds[-1, 1]
- )
+ """
+ t = get_Fermi_TOAs(
+ ft1name,
+ weightcolumn=weightcolumn,
+ targetcoord=targetcoord,
+ logeref=logeref,
+ logesig=logesig,
+ minweight=minweight,
+ minmjd=minmjd,
+ maxmjd=maxmjd,
+ fermiobs=fermiobs,
+ errors=errors,
)
- if weightcolumn is None:
- toalist = [
- toa.TOA(
- m, obs=obs, scale=scale, energy=str(e.to_value(u.MeV)), error=1.0 * u.us
- )
- for m, e in zip(mjds, energies)
- ]
- else:
- toalist = [
- toa.TOA(
- m,
- obs=obs,
- scale=scale,
- energy=str(e.to_value(u.MeV)),
- weight=str(w),
- error=1.0 * u.us,
- )
- for m, e, w in zip(mjds, energies, weights)
- ]
-
- return toalist
+ return t.to_TOA_list()
def get_Fermi_TOAs(
@@ -231,6 +154,9 @@ def get_Fermi_TOAs(
fermiobs="Fermi",
ephem=None,
planets=False,
+ include_bipm=False,
+ include_gps=False,
+ errors=_default_uncertainty,
):
"""
Read photon event times out of a Fermi FT1 file and return a :class:`pint.toa.TOAs` object
@@ -265,7 +191,13 @@ def get_Fermi_TOAs(
Whether to apply Shapiro delays based on planet positions. Note that a
long-standing TEMPO2 bug in this feature went unnoticed for years.
Defaults to False.
-
+ include_bipm : bool, optional
+ Use TT(BIPM) instead of TT(TAI)
+ include_gps : bool, optional
+ Apply GPS to UTC clock corrections
+ errors : astropy.units.Quantity or float, optional
+ The uncertainty on the TOA; if it's a float it is assumed to be
+ in microseconds
Returns
-------
@@ -328,6 +260,9 @@ def get_Fermi_TOAs(
energies = energies[idx]
weights = weights[idx]
+ if not isinstance(errors, u.Quantity):
+ errors = errors * u.microsecond
+
# limit the TOAs to ones in selected MJD range
mjds_float = np.asarray([r[0] + r[1] for r in mjds])
idx = (minmjd < mjds_float) & (mjds_float < maxmjd)
@@ -341,23 +276,25 @@ def get_Fermi_TOAs(
obs = "Barycenter"
scale = "tdb"
msg = "barycentric"
+ location = None
elif (timesys == "TT") and (timeref == "LOCAL"):
assert timesys == "TT"
try:
get_observatory(fermiobs)
except KeyError:
log.error(
- "%s observatory not defined. Make sure you have specified an FT2 file!"
- % fermiobs
+ f"{fermiobs} observatory not defined. Make sure you have specified an FT2 file!"
)
raise
obs = fermiobs
scale = "tt"
msg = "spacecraft local"
+ location = None
elif (timesys == "TT") and (timeref == "GEOCENTRIC"):
obs = "Geocenter"
scale = "tt"
msg = "geocentric"
+ location = EarthLocation(0, 0, 0)
else:
raise ValueError("Unrecognized TIMEREF/TIMESYS.")
@@ -370,31 +307,26 @@ def get_Fermi_TOAs(
val2=mjds[:, 1],
format="mjd",
scale=scale,
- location=EarthLocation(0, 0, 0),
+ location=location,
)
else:
- t = Time(
- mjds,
- format="mjd",
- scale=scale,
- location=EarthLocation(0, 0, 0),
- )
+ t = Time(mjds, format="mjd", scale=scale, location=location)
if weightcolumn is None:
return toa.get_TOAs_array(
t,
obs,
- include_gps=False,
- include_bipm=False,
+ errors=errors,
+ include_gps=include_gps,
+ include_bipm=include_bipm,
planets=planets,
ephem=ephem,
- flags=[
- {"energy": str(e), "weight": str(w)} for e in energies.to_value(u.MeV)
- ],
+ flags=[{"energy": str(e)} for e in energies.to_value(u.MeV)],
)
else:
return toa.get_TOAs_array(
t,
obs,
+ errors=errors,
include_gps=False,
include_bipm=False,
planets=planets,
diff --git a/src/pint/fits_utils.py b/src/pint/fits_utils.py
index 4012aaecf..61ab27d40 100644
--- a/src/pint/fits_utils.py
+++ b/src/pint/fits_utils.py
@@ -12,7 +12,7 @@
def read_fits_event_mjds_tuples(event_hdu, timecolumn="TIME"):
- """Read a set of MJDs from a FITS HDU, with proper converstion of times to MJD
+ """Read a set of MJDs from a FITS HDU, with proper conversion of times to MJD
The FITS time format is defined here:
https://heasarc.gsfc.nasa.gov/docs/journal/timing3.html
@@ -58,18 +58,13 @@ def read_fits_event_mjds_tuples(event_hdu, timecolumn="TIME"):
)
log.debug("MJDREF = {0}".format(MJDREF))
- # Should check timecolumn units to be sure they are seconds!
-
- # MJD = (TIMECOLUMN + TIMEZERO)/SECS_PER_DAY + MJDREF
- mjds = np.array(
+ return np.array(
[(MJDREF, tt) for tt in (event_dat.field(timecolumn) + TIMEZERO) / SECS_PER_DAY]
)
- return mjds
-
def read_fits_event_mjds(event_hdu, timecolumn="TIME"):
- """Read a set of MJDs from a FITS HDU, with proper converstion of times to MJD
+ """Read a set of MJDs from a FITS HDU, with proper conversion of times to MJD
The FITS time format is defined here:
https://heasarc.gsfc.nasa.gov/docs/journal/timing3.html
@@ -107,8 +102,6 @@ def read_fits_event_mjds(event_hdu, timecolumn="TIME"):
# Should check timecolumn units to be sure they are seconds!
# MJD = (TIMECOLUMN + TIMEZERO)/SECS_PER_DAY + MJDREF
- mjds = (
+ return (
np.array(event_dat.field(timecolumn), dtype=float) + TIMEZERO
) / SECS_PER_DAY + MJDREF
-
- return mjds
diff --git a/src/pint/fitter.py b/src/pint/fitter.py
index db40144d0..22e9860b3 100644
--- a/src/pint/fitter.py
+++ b/src/pint/fitter.py
@@ -57,6 +57,8 @@
>>> fitter = Fitter.auto(toas, model)
"""
+
+import contextlib
import copy
from warnings import warn
@@ -234,11 +236,12 @@ def __init__(self, toas, model, track_mode=None, residuals=None):
@classmethod
def auto(
- self, toas, model, downhill=True, track_mode=None, residuals=None, **kwargs
+ cls, toas, model, downhill=True, track_mode=None, residuals=None, **kwargs
):
"""Automatically return the proper :class:`pint.fitter.Fitter` object depending on the TOAs and model.
- In general the `downhill` fitters are to be preferred. See https://github.com/nanograv/PINT/wiki/How-To#choose-a-fitter for the logic used.
+ In general the `downhill` fitters are to be preferred.
+ See https://github.com/nanograv/PINT/wiki/How-To#choose-a-fitter for the logic used.
Parameters
----------
@@ -263,68 +266,66 @@ def auto(
if toas.wideband:
if downhill:
log.info(
- f"For wideband TOAs and downhill fitter, returning 'WidebandDownhillFitter'"
+ "For wideband TOAs and downhill fitter, returning 'WidebandDownhillFitter'"
)
return WidebandDownhillFitter(
toas, model, track_mode=track_mode, residuals=residuals, **kwargs
)
else:
log.info(
- f"For wideband TOAs and non-downhill fitter, returning 'WidebandTOAFitter'"
+ "For wideband TOAs and non-downhill fitter, returning 'WidebandTOAFitter'"
)
return WidebandTOAFitter(toas, model, track_mode=track_mode, **kwargs)
- else:
- if model.has_correlated_errors:
- if downhill:
- log.info(
- f"For narrowband TOAs with correlated errors and downhill fitter, returning 'DownhillGLSFitter'"
- )
- return DownhillGLSFitter(
- toas,
- model,
- track_mode=track_mode,
- residuals=residuals,
- **kwargs,
- )
- else:
- log.info(
- f"For narrowband TOAs with correlated errors and non-downhill fitter, returning 'GLSFitter'"
- )
- return GLSFitter(
- toas,
- model,
- track_mode=track_mode,
- residuals=residuals,
- **kwargs,
- )
+ elif model.has_correlated_errors:
+ if downhill:
+ log.info(
+ "For narrowband TOAs with correlated errors and downhill fitter, returning 'DownhillGLSFitter'"
+ )
+ return DownhillGLSFitter(
+ toas,
+ model,
+ track_mode=track_mode,
+ residuals=residuals,
+ **kwargs,
+ )
else:
- if downhill:
- log.info(
- f"For narrowband TOAs without correlated errors and downhill fitter, returning 'DownhillWLSFitter'"
- )
- return DownhillWLSFitter(
- toas,
- model,
- track_mode=track_mode,
- residuals=residuals,
- **kwargs,
- )
- else:
- log.info(
- f"For narrowband TOAs without correlated errors and non-downhill fitter, returning 'WLSFitter'"
- )
- return WLSFitter(
- toas,
- model,
- track_mode=track_mode,
- residuals=residuals,
- **kwargs,
- )
+ log.info(
+ "For narrowband TOAs with correlated errors and non-downhill fitter, returning 'GLSFitter'"
+ )
+ return GLSFitter(
+ toas,
+ model,
+ track_mode=track_mode,
+ residuals=residuals,
+ **kwargs,
+ )
+ elif downhill:
+ log.info(
+ "For narrowband TOAs without correlated errors and downhill fitter, returning 'DownhillWLSFitter'"
+ )
+ return DownhillWLSFitter(
+ toas,
+ model,
+ track_mode=track_mode,
+ residuals=residuals,
+ **kwargs,
+ )
+ else:
+ log.info(
+ "For narrowband TOAs without correlated errors and non-downhill fitter, returning 'WLSFitter'"
+ )
+ return WLSFitter(
+ toas,
+ model,
+ track_mode=track_mode,
+ residuals=residuals,
+ **kwargs,
+ )
def fit_toas(self, maxiter=None, debug=False):
"""Run fitting operation.
- This method needs to be implemented by subclasses. All implemenations
+ This method needs to be implemented by subclasses. All implementations
should call ``self.model.validate()`` and
``self.model.validate_toas()`` before doing the fitting.
"""
@@ -345,7 +346,6 @@ def get_summary(self, nodmx=False):
"fit_toas() has not been run, so pre-fit and post-fit will be the same!"
)
- import uncertainties.umath as um
from uncertainties import ufloat
# Check if Wideband or not
@@ -363,7 +363,7 @@ def get_summary(self, nodmx=False):
# to handle all parameter names, determine the longest length for the first column
longestName = 0 # optionally specify the minimum length here instead of 0
- for pn in self.model.params_ordered:
+ for pn in self.model.params:
if nodmx and pn.startswith("DMX"):
continue
if len(pn) > longestName:
@@ -378,7 +378,7 @@ def get_summary(self, nodmx=False):
s += ("{:<" + spacingName + "s} {:>20s} {:>28s} {}\n").format(
"=" * longestName, "=" * 20, "=" * 28, "=" * 5
)
- for pn in self.model.params_ordered:
+ for pn in self.model.params:
if nodmx and pn.startswith("DMX"):
continue
prefitpar = getattr(self.model_init, pn)
@@ -398,10 +398,11 @@ def get_summary(self, nodmx=False):
pn, str(prefitpar.quantity), "", par.units
)
else:
- if par.units == u.hourangle:
- uncertainty_unit = pint.hourangle_second
- else:
- uncertainty_unit = u.arcsec
+ uncertainty_unit = (
+ pint.hourangle_second
+ if par.units == u.hourangle
+ else u.arcsec
+ )
s += (
"{:" + spacingName + "s} {:>20s} {:>16s} +/- {:.2g} \n"
).format(
@@ -414,45 +415,39 @@ def get_summary(self, nodmx=False):
s += ("{:" + spacingName + "s} {:>20s} {:28s} {}\n").format(
pn, prefitpar.str_quantity(prefitpar.value), "", par.units
)
- else:
- # Assume a numerical parameter
- if par.frozen:
- if par.name == "START":
- if prefitpar.value is None:
- s += (
- "{:" + spacingName + "s} {:20s} {:28g} {} \n"
- ).format(pn, " ", par.value, par.units)
- else:
- s += (
- "{:" + spacingName + "s} {:20g} {:28g} {} \n"
- ).format(pn, prefitpar.value, par.value, par.units)
- elif par.name == "FINISH":
- if prefitpar.value is None:
- s += (
- "{:" + spacingName + "s} {:20s} {:28g} {} \n"
- ).format(pn, " ", par.value, par.units)
- else:
- s += (
- "{:" + spacingName + "s} {:20g} {:28g} {} \n"
- ).format(pn, prefitpar.value, par.value, par.units)
- else:
- s += ("{:" + spacingName + "s} {:20g} {:28s} {} \n").format(
- pn, prefitpar.value, "", par.units
- )
+ elif par.frozen:
+ if (
+ par.name == "START"
+ and prefitpar.value is None
+ or par.name != "START"
+ and par.name == "FINISH"
+ and prefitpar.value is None
+ ):
+ s += ("{:" + spacingName + "s} {:20s} {:28g} {} \n").format(
+ pn, " ", par.value, par.units
+ )
+ elif par.name in ["START", "FINISH"]:
+ s += ("{:" + spacingName + "s} {:20g} {:28g} {} \n").format(
+ pn, prefitpar.value, par.value, par.units
+ )
else:
- # s += "{:14s} {:20g} {:20g} {:20.2g} {} \n".format(
- # pn,
- # prefitpar.value,
- # par.value,
- # par.uncertainty.value,
- # par.units,
- # )
- s += ("{:" + spacingName + "s} {:20g} {:28SP} {} \n").format(
- pn,
- prefitpar.value,
- ufloat(par.value, par.uncertainty.value),
- par.units,
+ s += ("{:" + spacingName + "s} {:20g} {:28s} {} \n").format(
+ pn, prefitpar.value, "", par.units
)
+ else:
+ # s += "{:14s} {:20g} {:20g} {:20.2g} {} \n".format(
+ # pn,
+ # prefitpar.value,
+ # par.value,
+ # par.uncertainty.value,
+ # par.units,
+ # )
+ s += ("{:" + spacingName + "s} {:20g} {:28SP} {} \n").format(
+ pn,
+ prefitpar.value,
+ ufloat(par.value, par.uncertainty.value),
+ par.units,
+ )
s += "\n" + self.get_derived_params()
return s
@@ -468,45 +463,30 @@ def get_derived_params(self):
F0 = self.model.F0.quantity
if not self.model.F0.frozen:
p, perr = pint.derived_quantities.pferrs(F0, self.model.F0.uncertainty)
- s += "Period = {} +/- {}\n".format(p.to(u.s), perr.to(u.s))
+ s += f"Period = {p.to(u.s)} +/- {perr.to(u.s)}\n"
else:
- s += "Period = {}\n".format((1.0 / F0).to(u.s))
+ s += f"Period = {(1.0 / F0).to(u.s)}\n"
if hasattr(self.model, "F1"):
F1 = self.model.F1.quantity
if not any([self.model.F1.frozen, self.model.F0.frozen]):
p, perr, pd, pderr = pint.derived_quantities.pferrs(
F0, self.model.F0.uncertainty, F1, self.model.F1.uncertainty
)
- s += "Pdot = {} +/- {}\n".format(
- pd.to(u.dimensionless_unscaled), pderr.to(u.dimensionless_unscaled)
- )
+ s += f"Pdot = {pd.to(u.dimensionless_unscaled)} +/- {pderr.to(u.dimensionless_unscaled)}\n"
if F1.value < 0.0: # spinning-down
brakingindex = 3
- s += "Characteristic age = {:.4g} (braking index = {})\n".format(
- pint.derived_quantities.pulsar_age(F0, F1, n=brakingindex),
- brakingindex,
- )
- s += "Surface magnetic field = {:.3g}\n".format(
- pint.derived_quantities.pulsar_B(F0, F1)
- )
- s += "Magnetic field at light cylinder = {:.4g}\n".format(
- pint.derived_quantities.pulsar_B_lightcyl(F0, F1)
- )
+ s += f"Characteristic age = {pint.derived_quantities.pulsar_age(F0, F1, n=brakingindex):.4g} (braking index = {brakingindex})\n"
+ s += f"Surface magnetic field = {pint.derived_quantities.pulsar_B(F0, F1):.3g}\n"
+ s += f"Magnetic field at light cylinder = {pint.derived_quantities.pulsar_B_lightcyl(F0, F1):.4g}\n"
I_NS = I = 1.0e45 * u.g * u.cm**2
- s += "Spindown Edot = {:.4g} (I={})\n".format(
- pint.derived_quantities.pulsar_edot(F0, F1, I=I_NS), I_NS
- )
+ s += f"Spindown Edot = {pint.derived_quantities.pulsar_edot(F0, F1, I=I_NS):.4g} (I={I_NS})\n"
else:
s += "Not computing Age, B, or Edot since F1 > 0.0\n"
- if hasattr(self.model, "PX"):
- if not self.model.PX.frozen:
- s += "\n"
- px = ufloat(
- self.model.PX.quantity.to(u.arcsec).value,
- self.model.PX.uncertainty.to(u.arcsec).value,
- )
- s += "Parallax distance = {:.3uP} pc\n".format(1.0 / px)
+ if hasattr(self.model, "PX") and not self.model.PX.frozen:
+ s += "\n"
+ px = self.model.PX.as_ufloat(u.arcsec)
+ s += f"Parallax distance = {1.0/px:.3uP} pc\n"
# Now binary system derived parameters
if self.model.is_binary:
@@ -514,7 +494,7 @@ def get_derived_params(self):
if x.startswith("Binary"):
binary = x
- s += "\nBinary model {}\n".format(binary)
+ s += f"\nBinary model {binary}\n"
btx = False
if (
@@ -525,14 +505,12 @@ def get_derived_params(self):
btx = True
FB0 = self.model.FB0.quantity
if not self.model.FB0.frozen:
- p, perr = pint.derived_quantities.pferrs(
+ pb, pberr = pint.derived_quantities.pferrs(
FB0, self.model.FB0.uncertainty
)
- s += "Orbital Period (PB) = {} +/- {}\n".format(
- p.to(u.d), perr.to(u.d)
- )
+ s += f"Orbital Period (PB) = {pb.to(u.d)} +/- {pberr.to(u.d)}\n"
else:
- s += "Orbital Period (PB) = {}\n".format((1.0 / FB0).to(u.d))
+ s += f"Orbital Period (PB) = {(1.0 / FB0).to(u.d)}\n"
if (
hasattr(self.model, "FB1")
@@ -541,53 +519,40 @@ def get_derived_params(self):
):
FB1 = self.model.FB1.quantity
if not any([self.model.FB1.frozen, self.model.FB0.frozen]):
- p, perr, pd, pderr = pint.derived_quantities.pferrs(
+ pb, pberr, pbd, pbderr = pint.derived_quantities.pferrs(
FB0, self.model.FB0.uncertainty, FB1, self.model.FB1.uncertainty
)
- s += "Orbital Pdot (PBDOT) = {} +/- {}\n".format(
- pd.to(u.dimensionless_unscaled),
- pderr.to(u.dimensionless_unscaled),
- )
+ s += f"Orbital Pdot (PBDOT) = {pbd.to(u.dimensionless_unscaled)} +/- {pbderr.to(u.dimensionless_unscaled)}\n"
ell1 = False
if binary.startswith("BinaryELL1"):
ell1 = True
- eps1 = ufloat(
- self.model.EPS1.quantity.value,
- self.model.EPS1.uncertainty.value,
- )
- eps2 = ufloat(
- self.model.EPS2.quantity.value,
- self.model.EPS2.uncertainty.value,
- )
+ eps1 = self.model.EPS1.as_ufloat()
+ eps2 = self.model.EPS2.as_ufloat()
tasc = ufloat(
# This is a time in MJD
self.model.TASC.quantity.mjd,
self.model.TASC.uncertainty.to(u.d).value,
)
if hasattr(self.model, "PB") and self.model.PB.value is not None:
- pb = ufloat(
- self.model.PB.quantity.to(u.d).value,
- self.model.PB.uncertainty.to(u.d).value,
- )
+ pb = self.model.PB.as_ufloat(u.d)
elif hasattr(self.model, "FB0") and self.model.FB0.value is not None:
- p, perr = pint.derived_quantities.pferrs(
- self.model.FB0.quantity, self.model.FB0.uncertainty
- )
- pb = ufloat(p.to(u.d).value, perr.to(u.d).value)
+ pb = 1 / self.model.FB0.as_ufloat(1 / u.d)
s += "Conversion from ELL1 parameters:\n"
ecc = um.sqrt(eps1**2 + eps2**2)
s += "ECC = {:P}\n".format(ecc)
om = um.atan2(eps1, eps2) * 180.0 / np.pi
if om < 0.0:
om += 360.0
- s += "OM = {:P} deg\n".format(om)
+ s += f"OM = {om:P} deg\n"
t0 = tasc + pb * om / 360.0
- s += "T0 = {:SP}\n".format(t0)
+ s += f"T0 = {t0:SP}\n"
- a1 = self.model.A1.quantity
- if a1 is None:
- a1 = 0 * pint.ls
+ a1 = (
+ self.model.A1.quantity
+ if self.model.A1.quantity is not None
+ else 0 * pint.ls
+ )
if self.is_wideband:
s += pint.utils.ELL1_check(
a1,
@@ -605,24 +570,24 @@ def get_derived_params(self):
outstring=True,
)
s += "\n"
-
+ if hasattr(self.model, "FB0") and self.model.FB0.value is not None:
+ pb, pberr = pint.derived_quantities.pferrs(
+ self.model.FB0.quantity, self.model.FB0.uncertainty
+ )
# Masses and inclination
- pb = p.to(u.d) if btx else self.model.PB.quantity
- pberr = perr.to(u.d) if btx else self.model.PB.uncertainty
+ pb = pb.to(u.d) if btx else self.model.PB.quantity
+ pberr = pberr.to(u.d) if btx else self.model.PB.uncertainty
if not self.model.A1.frozen:
pbs = ufloat(
pb.to(u.s).value,
pberr.to(u.s).value,
)
- a1 = ufloat(
- self.model.A1.quantity.to(pint.ls).value,
- self.model.A1.uncertainty.to(pint.ls).value,
- )
+ a1 = self.model.A1.as_ufloat(pint.ls)
# This is the mass function, done explicitly so that we get
# uncertainty propagation automatically.
# TODO: derived quantities funcs should take uncertainties
fm = 4.0 * np.pi**2 * a1**3 / (4.925490947e-6 * pbs**2)
- s += "Mass function = {:SP} Msun\n".format(fm)
+ s += f"Mass function = {fm:SP} Msun\n"
mcmed = pint.derived_quantities.companion_mass(
pb,
self.model.A1.quantity,
@@ -635,9 +600,7 @@ def get_derived_params(self):
i=90.0 * u.deg,
mp=1.4 * u.solMass,
)
- s += "Min / Median Companion mass (assuming Mpsr = 1.4 Msun) = {:.4f} / {:.4f} Msun\n".format(
- mcmin.value, mcmed.value
- )
+ s += f"Min / Median Companion mass (assuming Mpsr = 1.4 Msun) = {mcmin.value:.4f} / {mcmed.value:.4f} Msun\n"
if (
hasattr(self.model, "OMDOT")
@@ -667,23 +630,20 @@ def get_derived_params(self):
)
Mtot_err = max(abs(Mtot_hi - Mtot), abs(Mtot - Mtot_lo))
mt = ufloat(Mtot.value, Mtot_err.value)
- s += "Total mass, assuming GR, from OMDOT is {:SP} Msun\n".format(mt)
+ s += f"Total mass, assuming GR, from OMDOT is {mt:SP} Msun\n"
if (
hasattr(self.model, "SINI")
and self.model.SINI.quantity is not None
and (self.model.SINI.value >= 0.0 and self.model.SINI.value < 1.0)
):
- try:
+ with contextlib.suppress(TypeError, ValueError):
# Put this in a try in case SINI is UNSET or an illegal value
if not self.model.SINI.frozen:
- si = ufloat(
- self.model.SINI.quantity.value,
- self.model.SINI.uncertainty.value,
- )
- s += "From SINI in model:\n"
- s += " cos(i) = {:SP}\n".format(um.sqrt(1 - si**2))
- s += " i = {:SP} deg\n".format(um.asin(si) * 180.0 / np.pi)
+ si = self.model.SINI.as_ufloat()
+ s += f"From SINI in model:\n"
+ s += f" cos(i) = {um.sqrt(1 - si**2):SP}\n"
+ s += f" i = {um.asin(si) * 180.0 / np.pi:SP} deg\n"
psrmass = pint.derived_quantities.pulsar_mass(
pb,
@@ -691,9 +651,7 @@ def get_derived_params(self):
self.model.M2.quantity,
np.arcsin(self.model.SINI.quantity),
)
- s += "Pulsar mass (Shapiro Delay) = {}".format(psrmass)
- except (TypeError, ValueError):
- pass
+ s += f"Pulsar mass (Shapiro Delay) = {psrmass}"
return s
def print_summary(self):
@@ -736,10 +694,11 @@ def update_model(self, chi2=None):
self.model.NTOA.value = len(self.toas)
self.model.EPHEM.value = self.toas.ephem
self.model.DMDATA.value = hasattr(self.resids, "dm")
- if not self.toas.clock_corr_info["include_bipm"]:
- self.model.CLOCK.value = "TT(TAI)"
- else:
- self.model.CLOCK.value = f"TT({self.toas.clock_corr_info['bipm_version']})"
+ self.model.CLOCK.value = (
+ f"TT({self.toas.clock_corr_info['bipm_version']})"
+ if self.toas.clock_corr_info["include_bipm"]
+ else "TT(TAI)"
+ )
def reset_model(self):
"""Reset the current model to the initial model."""
@@ -864,7 +823,7 @@ def ftest(self, parameter, component, remove=False, full_output=False, maxiter=1
NB = not self.is_wideband
# Copy the fitter that we do not change the initial model and fitter
fitter_copy = copy.deepcopy(self)
- # We need the original degrees of freedome and chi-squared value
+ # We need the original degrees of freedom and chi-squared value
# Because this applies to nested models, model 1 must always have fewer parameters
if remove:
dof_2 = self.resids.dof
@@ -872,7 +831,7 @@ def ftest(self, parameter, component, remove=False, full_output=False, maxiter=1
else:
dof_1 = self.resids.dof
chi2_1 = self.resids.chi2
- # Single inputs are converted to lists to handle arb. number of parameteres
+ # Single inputs are converted to lists to handle arb. number of parameters
if type(parameter) is not list:
parameter = [parameter]
# also do the components
@@ -917,17 +876,17 @@ def ftest(self, parameter, component, remove=False, full_output=False, maxiter=1
fitter_copy.model, "{:}".format(parameter[ii].name)
).frozen = False
# Check if parameter is one that needs to be checked
- if parameter[ii].name in check_params.keys():
- if parameter[ii].value == 0.0:
- log.warning(
- "Default value for %s cannot be 0, resetting to %s"
- % (parameter[ii].name, check_params[parameter[ii].name])
- )
- parameter[ii].value = check_params[parameter[ii].name]
+ if (
+ parameter[ii].name in check_params
+ and parameter[ii].value == 0.0
+ ):
+ log.warning(
+ f"Default value for {parameter[ii].name} cannot be 0, resetting to {check_params[parameter[ii].name]}"
+ )
+ parameter[ii].value = check_params[parameter[ii].name]
getattr(
fitter_copy.model, "{:}".format(parameter[ii].name)
).value = parameter[ii].value
- # If not, add it to the model
else:
fitter_copy.model.components[component[ii]].add_param(
parameter[ii], setup=True
@@ -944,40 +903,39 @@ def ftest(self, parameter, component, remove=False, full_output=False, maxiter=1
# Now run the actual F-test
ft = FTest(chi2_1, dof_1, chi2_2, dof_2)
- if full_output:
- if remove:
- dof_test = dof_1
- chi2_test = chi2_1
- else:
- dof_test = dof_2
- chi2_test = chi2_2
- if NB:
- resid_rms_test = fitter_copy.resids.time_resids.std().to(u.us)
- resid_wrms_test = fitter_copy.resids.rms_weighted() # units: us
- return {
- "ft": ft,
- "resid_rms_test": resid_rms_test,
- "resid_wrms_test": resid_wrms_test,
- "chi2_test": chi2_test,
- "dof_test": dof_test,
- }
- else:
- # Return the dm and time resid values separately
- resid_rms_test = fitter_copy.resids.toa.time_resids.std().to(u.us)
- resid_wrms_test = fitter_copy.resids.toa.rms_weighted() # units: us
- dm_resid_rms_test = fitter_copy.resids.dm.resids.std()
- dm_resid_wrms_test = fitter_copy.resids.dm.rms_weighted()
- return {
- "ft": ft,
- "resid_rms_test": resid_rms_test,
- "resid_wrms_test": resid_wrms_test,
- "chi2_test": chi2_test,
- "dof_test": dof_test,
- "dm_resid_rms_test": dm_resid_rms_test,
- "dm_resid_wrms_test": dm_resid_wrms_test,
- }
- else:
+ if not full_output:
return {"ft": ft}
+ if remove:
+ dof_test = dof_1
+ chi2_test = chi2_1
+ else:
+ dof_test = dof_2
+ chi2_test = chi2_2
+ if NB:
+ resid_rms_test = fitter_copy.resids.time_resids.std().to(u.us)
+ resid_wrms_test = fitter_copy.resids.rms_weighted() # units: us
+ return {
+ "ft": ft,
+ "resid_rms_test": resid_rms_test,
+ "resid_wrms_test": resid_wrms_test,
+ "chi2_test": chi2_test,
+ "dof_test": dof_test,
+ }
+ else:
+ # Return the dm and time resid values separately
+ resid_rms_test = fitter_copy.resids.toa.time_resids.std().to(u.us)
+ resid_wrms_test = fitter_copy.resids.toa.rms_weighted() # units: us
+ dm_resid_rms_test = fitter_copy.resids.dm.resids.std()
+ dm_resid_wrms_test = fitter_copy.resids.dm.rms_weighted()
+ return {
+ "ft": ft,
+ "resid_rms_test": resid_rms_test,
+ "resid_wrms_test": resid_wrms_test,
+ "chi2_test": chi2_test,
+ "dof_test": dof_test,
+ "dm_resid_rms_test": dm_resid_rms_test,
+ "dm_resid_wrms_test": dm_resid_wrms_test,
+ }
def minimize_func(self, x, *args):
"""Wrapper function for the residual class.
@@ -987,7 +945,7 @@ def minimize_func(self, x, *args):
values, x, and a second optional tuple of input arguments. It returns
a quantity to be minimized (in this case chi^2).
"""
- self.set_params({k: v for k, v in zip(args, x)})
+ self.set_params(dict(zip(args, x)))
self.update_resids()
# Return chi^2
return self.resids.chi2
@@ -1017,7 +975,7 @@ def set_fitparams(self, *params):
if rn != "":
fit_params_name.append(rn)
else:
- raise ValueError("Unrecognized parameter {}".format(pn))
+ raise ValueError(f"Unrecognized parameter {pn}")
self.model.fit_params = fit_params_name
def get_allparams(self):
@@ -1148,11 +1106,8 @@ def take_step_model(self, step, lambda_=1):
new_model = copy.deepcopy(self.model)
for p, s in zip(self.params, step * lambda_):
try:
- try:
+ with contextlib.suppress(ValueError):
log.trace(f"Adjusting {getattr(self.model, p)} by {s}")
- except ValueError:
- # I don't know why this fails with multiprocessing, but bypass if it does
- pass
pm = getattr(new_model, p)
if pm.value is None:
pm.value = 0
@@ -1205,7 +1160,7 @@ def fit_toas(
one; if the new model is invalid or worse than the current one, it
tries taking a shorter step in the same direction. This can exit if it
exceeds the maximum number of iterations or if improvement is not
- possible even with very short steps, or it can exit successully if a
+ possible even with very short steps, or it can exit successfully if a
full-size step is taken and it does not decrease the ``chi2`` by much.
The attribute ``self.converged`` is set to True or False depending on
@@ -1249,16 +1204,15 @@ def fit_toas(
f"chi2 increased from {current_state.chi2} to {new_state.chi2} "
f"when trying to take a step with lambda {lambda_}"
)
- else:
- log.trace(
- f"Iteration {i}: "
- f"Updating state, chi2 goes down by {chi2_decrease} "
- f"from {current_state.chi2} "
- f"to {new_state.chi2}"
- )
- exception = None
- current_state = new_state
- break
+ log.trace(
+ f"Iteration {i}: "
+ f"Updating state, chi2 goes down by {chi2_decrease} "
+ f"from {current_state.chi2} "
+ f"to {new_state.chi2}"
+ )
+ exception = None
+ current_state = new_state
+ break
except InvalidModelParameters as e:
# This could be an exception evaluating new_state.chi2 or an increase in value
# If bad parameter values escape, look in ModelState.resids for the except
@@ -1286,7 +1240,7 @@ def fit_toas(
break
else:
log.debug(
- f"Stopping because maxmum number of iterations ({maxiter}) reached"
+ f"Stopping because maximum number of iterations ({maxiter}) reached"
)
self.current_state = best_state
# collect results
@@ -1301,11 +1255,9 @@ def fit_toas(
)
for p, e in zip(self.current_state.params, self.errors):
try:
- try:
+ # I don't know why this fails with multiprocessing, but bypass if it does
+ with contextlib.suppress(ValueError):
log.trace(f"Setting {getattr(self.model, p)} uncertainty to {e}")
- except ValueError:
- # I don't know why this fails with multiprocessing, but bypass if it does
- pass
pm = getattr(self.model, p)
except AttributeError:
if p != "Offset":
@@ -1400,9 +1352,10 @@ def step(self):
self.units = units
self.scaled_resids = scaled_resids
# TODO: seems like doing this on every iteration is wasteful, and we should just do it once and then update the matrix
- covariance_matrix_labels = {}
- for i, (param, unit) in enumerate(zip(params, units)):
- covariance_matrix_labels[param] = (i, i + 1, unit)
+ covariance_matrix_labels = {
+ param: (i, i + 1, unit)
+ for i, (param, unit) in enumerate(zip(params, units))
+ }
# covariance matrix is 2D and symmetric
covariance_matrix_labels = [covariance_matrix_labels] * 2
self.parameter_covariance_matrix_labels = covariance_matrix_labels
@@ -1484,9 +1437,10 @@ def step(self):
self.params = params
self.units = units
# TODO: seems like doing this on every iteration is wasteful, and we should just do it once and then update the matrix
- covariance_matrix_labels = {}
- for i, (param, unit) in enumerate(zip(params, units)):
- covariance_matrix_labels[param] = (i, i + 1, unit)
+ covariance_matrix_labels = {
+ param: (i, i + 1, unit)
+ for i, (param, unit) in enumerate(zip(params, units))
+ }
# covariance matrix is 2D and symmetric
covariance_matrix_labels = [covariance_matrix_labels] * 2
self.parameter_covariance_matrix_labels = covariance_matrix_labels
@@ -1637,10 +1591,10 @@ def fit_toas(self, maxiter=10, threshold=0, full_cov=False, debug=False, **kwarg
r = super().fit_toas(maxiter=maxiter, debug=debug, **kwargs)
# FIXME: set up noise residuals et cetera
# Compute the noise realizations if possible
- ntmpar = len(self.model.free_params)
if not self.full_cov:
noise_dims = self.model.noise_model_dimensions(self.toas)
noise_resids = {}
+ ntmpar = len(self.model.free_params)
for comp in noise_dims:
# The first column of designmatrix is "offset", add 1 to match
# the indices of noise designmatrix
@@ -1655,13 +1609,13 @@ def fit_toas(self, maxiter=10, threshold=0, full_cov=False, debug=False, **kwarg
if debug:
setattr(
self.resids,
- comp + "_M",
+ f"{comp}_M",
(
self.current_state.M[:, p0:p1],
self.current_state.xhat[p0:p1],
),
)
- setattr(self.resids, comp + "_M_index", (p0, p1))
+ setattr(self.resids, f"{comp}_M_index", (p0, p1))
self.resids.noise_resids = noise_resids
if debug:
setattr(self.resids, "norm", self.current_state.norm)
@@ -1857,9 +1811,10 @@ def parameter_covariance_matrix(self):
# make sure we compute the SVD
xvar = np.dot(self.Vt.T / self.s, self.Vt)
# is this the best place to do this?
- covariance_matrix_labels = {}
- for i, (param, unit) in enumerate(zip(self.params, self.units)):
- covariance_matrix_labels[param] = (i, i + 1, unit)
+ covariance_matrix_labels = {
+ param: (i, i + 1, unit)
+ for i, (param, unit) in enumerate(zip(self.params, self.units))
+ }
# covariance matrix is 2D and symmetric
covariance_matrix_labels = [covariance_matrix_labels] * 2
@@ -1926,11 +1881,11 @@ def fit_toas(
self.full_cov = full_cov
# FIXME: set up noise residuals et cetera
r = super().fit_toas(maxiter=maxiter, debug=debug, **kwargs)
- # Compute the noise realizations if possible
- ntmpar = len(self.model.free_params)
+ # Compute the noise realizations if possibl
if not self.full_cov:
noise_dims = self.model.noise_model_dimensions(self.toas)
noise_resids = {}
+ ntmpar = len(self.model.free_params)
for comp in noise_dims:
# The first column of designmatrix is "offset", add 1 to match
# the indices of noise designmatrix
@@ -1945,13 +1900,13 @@ def fit_toas(
if debug:
setattr(
self.resids,
- comp + "_M",
+ f"{comp}_M",
(
self.current_state.M[:, p0:p1],
self.current_state.xhat[p0:p1],
),
)
- setattr(self.resids, comp + "_M_index", (p0, p1))
+ setattr(self.resids, f"{comp}_M_index", (p0, p1))
self.resids.noise_resids = noise_resids
if debug:
setattr(self.resids, "norm", self.current_state.norm)
@@ -2037,7 +1992,7 @@ def fit_toas(self, maxiter=1, threshold=None, debug=False):
self.model.validate()
self.model.validate_toas(self.toas)
chi2 = 0
- for i in range(maxiter):
+ for _ in range(maxiter):
fitp = self.model.get_params_dict("free", "quantity")
fitpv = self.model.get_params_dict("free", "num")
fitperrs = self.model.get_params_dict("free", "uncertainty")
@@ -2109,10 +2064,10 @@ def fit_toas(self, maxiter=1, threshold=None, debug=False):
sigma_cov = (sigma_var / errors).T / errors
# covariance matrix = variances in diagonal, used for gaussian random models
covariance_matrix = sigma_var
- # TODO: seems like doing this on every iteration is wasteful, and we should just do it once and then update the matrix
- covariance_matrix_labels = {}
- for i, (param, unit) in enumerate(zip(params, units)):
- covariance_matrix_labels[param] = (i, i + 1, unit)
+ covariance_matrix_labels = {
+ param: (i, i + 1, unit)
+ for i, (param, unit) in enumerate(zip(params, units))
+ }
# covariance matrix is 2D and symmetric
covariance_matrix_labels = [
covariance_matrix_labels
@@ -2132,7 +2087,7 @@ def fit_toas(self, maxiter=1, threshold=None, debug=False):
# dpars = V s^-1 U^T r
# Scaling by fac recovers original units
dpars = np.dot(Vt.T, np.dot(U.T, residuals) / s) / fac
- for ii, pn in enumerate(fitp.keys()):
+ for pn in fitp.keys():
uind = params.index(pn) # Index of designmatrix
un = 1.0 / (units[uind]) # Unit in designmatrix
un *= u.s
@@ -2302,12 +2257,10 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
dpars = xhat / norm
errs = np.sqrt(np.diag(xvar)) / norm
covmat = (xvar / norm).T / norm
- # self.covariance_matrix = covmat
- # self.correlation_matrix = (covmat / errs).T / errs
- # TODO: seems like doing this on every iteration is wasteful, and we should just do it once and then update the matrix
- covariance_matrix_labels = {}
- for i, (param, unit) in enumerate(zip(params, units)):
- covariance_matrix_labels[param] = (i, i + 1, unit)
+ covariance_matrix_labels = {
+ param: (i, i + 1, unit)
+ for i, (param, unit) in enumerate(zip(params, units))
+ }
# covariance matrix is 2D and symmetric
covariance_matrix_labels = [covariance_matrix_labels] * covmat.ndim
self.parameter_covariance_matrix = CovarianceMatrix(
@@ -2317,7 +2270,7 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
(covmat / errs).T / errs, covariance_matrix_labels
)
- for ii, pn in enumerate(fitp.keys()):
+ for pn in fitp.keys():
uind = params.index(pn) # Index of designmatrix
un = 1.0 / (units[uind]) # Unit in designmatrix
un *= u.s
@@ -2343,8 +2296,8 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
p1 = p0 + noise_dims[comp][1]
noise_resids[comp] = np.dot(M[:, p0:p1], xhat[p0:p1]) * u.s
if debug:
- setattr(self.resids, comp + "_M", (M[:, p0:p1], xhat[p0:p1]))
- setattr(self.resids, comp + "_M_index", (p0, p1))
+ setattr(self.resids, f"{comp}_M", (M[:, p0:p1], xhat[p0:p1]))
+ setattr(self.resids, f"{comp}_M_index", (p0, p1))
self.resids.noise_resids = noise_resids
if debug:
setattr(self.resids, "norm", norm)
@@ -2410,21 +2363,17 @@ def __init__(
# Get the makers for fitting parts.
self.reset_model()
self.resids_init = copy.deepcopy(self.resids)
- self.designmatrix_makers = []
- for data_resids in self.resids.residual_objs.values():
- self.designmatrix_makers.append(
- DesignMatrixMaker(data_resids.residual_type, data_resids.unit)
- )
-
+ self.designmatrix_makers = [
+ DesignMatrixMaker(data_resids.residual_type, data_resids.unit)
+ for data_resids in self.resids.residual_objs.values()
+ ]
# Add noise design matrix maker
self.noise_designmatrix_maker = DesignMatrixMaker("toa_noise", u.s)
#
- self.covariancematrix_makers = []
- for data_resids in self.resids.residual_objs.values():
- self.covariancematrix_makers.append(
- CovarianceMatrixMaker(data_resids.residual_type, data_resids.unit)
- )
-
+ self.covariancematrix_makers = [
+ CovarianceMatrixMaker(data_resids.residual_type, data_resids.unit)
+ for data_resids in self.resids.residual_objs.values()
+ ]
self.is_wideband = True
self.method = "General_Data_Fitter"
@@ -2455,29 +2404,30 @@ def get_designmatrix(self):
design_matrixs = []
fit_params = self.model.free_params
if len(self.fit_data) == 1:
- for ii, dmatrix_maker in enumerate(self.designmatrix_makers):
- design_matrixs.append(
- dmatrix_maker(self.fit_data[0], self.model, fit_params, offset=True)
- )
+ design_matrixs.extend(
+ dmatrix_maker(self.fit_data[0], self.model, fit_params, offset=True)
+ for dmatrix_maker in self.designmatrix_makers
+ )
else:
- for ii, dmatrix_maker in enumerate(self.designmatrix_makers):
- design_matrixs.append(
- dmatrix_maker(
- self.fit_data[ii], self.model, fit_params, offset=True
- )
- )
+ design_matrixs.extend(
+ dmatrix_maker(self.fit_data[ii], self.model, fit_params, offset=True)
+ for ii, dmatrix_maker in enumerate(self.designmatrix_makers)
+ )
return combine_design_matrices_by_quantity(design_matrixs)
def get_noise_covariancematrix(self):
# TODO This needs to be more general
cov_matrixs = []
if len(self.fit_data) == 1:
- for ii, cmatrix_maker in enumerate(self.covariancematrix_makers):
- cov_matrixs.append(cmatrix_maker(self.fit_data[0], self.model))
+ cov_matrixs.extend(
+ cmatrix_maker(self.fit_data[0], self.model)
+ for cmatrix_maker in self.covariancematrix_makers
+ )
else:
- for ii, cmatrix_maker in enumerate(self.covariancematrix_makers):
- cov_matrixs.append(cmatrix_maker(self.fit_data[ii], self.model))
-
+ cov_matrixs.extend(
+ cmatrix_maker(self.fit_data[ii], self.model)
+ for ii, cmatrix_maker in enumerate(self.covariancematrix_makers)
+ )
return combine_covariance_matrix(cov_matrixs)
def get_data_uncertainty(self, data_name, data_obj):
@@ -2503,7 +2453,7 @@ def scaled_all_sigma(self):
scaled_sigmas = []
sigma_units = []
for ii, fd_name in enumerate(self.fit_data_names):
- func_name = "scaled_{}_uncertainty".format(fd_name)
+ func_name = f"scaled_{fd_name}_uncertainty"
sigma_units.append(self.resids.residual_objs[fd_name].unit)
if hasattr(self.model, func_name):
scale_func = getattr(self.model, func_name)
@@ -2660,9 +2610,10 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
errs = np.sqrt(np.diag(xvar)) / norm
covmat = (xvar / norm).T / norm
# TODO: seems like doing this on every iteration is wasteful, and we should just do it once and then update the matrix
- covariance_matrix_labels = {}
- for i, (param, unit) in enumerate(zip(params, units)):
- covariance_matrix_labels[param] = (i, i + 1, unit)
+ covariance_matrix_labels = {
+ param: (i, i + 1, unit)
+ for i, (param, unit) in enumerate(zip(params, units))
+ }
# covariance matrix is 2D and symmetric
covariance_matrix_labels = [covariance_matrix_labels] * covmat.ndim
self.parameter_covariance_matrix = CovarianceMatrix(
@@ -2675,7 +2626,7 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
# self.covariance_matrix = covmat
# self.correlation_matrix = (covmat / errs).T / errs
- for ii, pn in enumerate(fitp.keys()):
+ for pn in fitp.keys():
uind = params.index(pn) # Index of designmatrix
# Here we use design matrix's label, so the unit goes to normal.
# instead of un = 1 / (units[uind])
@@ -2702,8 +2653,8 @@ def fit_toas(self, maxiter=1, threshold=0, full_cov=False, debug=False):
p1 = p0 + noise_dims[comp][1]
noise_resids[comp] = np.dot(M[:, p0:p1], xhat[p0:p1]) * u.s
if debug:
- setattr(self.resids, comp + "_M", (M[:, p0:p1], xhat[p0:p1]))
- setattr(self.resids, comp + "_M_index", (p0, p1))
+ setattr(self.resids, f"{comp}_M", (M[:, p0:p1], xhat[p0:p1]))
+ setattr(self.resids, f"{comp}_M_index", (p0, p1))
self.resids.noise_resids = noise_resids
if debug:
setattr(self.resids, "norm", norm)
@@ -2784,9 +2735,9 @@ def fit_toas(
chi2_decrease = current_state.chi2 - new_state.chi2
if chi2_decrease < -min_chi2_decrease:
lambda_ *= (
- lambda_factor_increase
- if not ill_conditioned
- else lambda_factor_invalid
+ lambda_factor_invalid
+ if ill_conditioned
+ else lambda_factor_increase
)
log.trace(
f"Iteration {i}: chi2 increased from {current_state.chi2} "
@@ -2899,10 +2850,10 @@ def update_from_state(self, state, debug=False):
self.update_model(state.chi2)
# Compute the noise realizations if possible
- ntmpar = len(self.model.free_params)
if not self.full_cov:
noise_dims = self.model.noise_model_dimensions(self.toas)
noise_resids = {}
+ ntmpar = len(self.model.free_params)
for comp in noise_dims:
# The first column of designmatrix is "offset", add 1 to match
# the indices of noise designmatrix
@@ -2911,9 +2862,9 @@ def update_from_state(self, state, debug=False):
noise_resids[comp] = np.dot(state.M[:, p0:p1], state.xhat[p0:p1]) * u.s
if debug:
setattr(
- self.resids, comp + "_M", (state.M[:, p0:p1], state.xhat[p0:p1])
+ self.resids, f"{comp}_M", (state.M[:, p0:p1], state.xhat[p0:p1])
)
- setattr(self.resids, comp + "_M_index", (p0, p1))
+ setattr(self.resids, f"{comp}_M_index", (p0, p1))
self.resids.noise_resids = noise_resids
if debug:
setattr(self.resids, "norm", state.norm)
diff --git a/src/pint/gridutils.py b/src/pint/gridutils.py
index ff7fd66c2..4a41d1c1d 100644
--- a/src/pint/gridutils.py
+++ b/src/pint/gridutils.py
@@ -3,8 +3,10 @@
import copy
import multiprocessing
import subprocess
+import sys
import numpy as np
+
from loguru import logger as log
try:
@@ -15,7 +17,7 @@
from astropy.utils.console import ProgressBar
from pint import fitter
-
+from pint.observatory import clock_file
__all__ = ["doonefit", "grid_chisq", "grid_chisq_derived"]
@@ -24,6 +26,11 @@ def hostinfo():
return subprocess.check_output("uname -a", shell=True)
+def set_log(logger_):
+ global log
+ log = logger_
+
+
class WrappedFitter:
"""Worker class to compute one fit with specified parameters fixed but passing other parameters to fit_toas()"""
@@ -58,6 +65,12 @@ def doonefit(self, parnames, parvalues, extraparnames=[]):
"""
# Make a full copy of the fitter to work with
myftr = copy.deepcopy(self.ftr)
+ # copy the log to all imported modules
+ # this makes them respect the logger settings
+ for m in sys.modules:
+ if m.startswith("pint") and hasattr(sys.modules[m], "log"):
+ setattr(sys.modules[m], "log", log)
+
parstrings = []
for parname, parvalue in zip(parnames, parvalues):
# Freeze the params we are going to grid over and set their values
@@ -281,14 +294,6 @@ def grid_chisq(
.. [1] https://mpi4py.readthedocs.io/en/stable/mpi4py.futures.html#mpipoolexecutor
.. [2] https://github.com/sampsyo/clusterfutures
"""
- if isinstance(executor, concurrent.futures.Executor):
- # the executor has already been created
- executor = executor
- elif executor is None and (ncpu is None or ncpu > 1):
- # make the default type of Executor
- if ncpu is None:
- ncpu = multiprocessing.cpu_count()
- executor = concurrent.futures.ProcessPoolExecutor(max_workers=ncpu)
# Save the current model so we can tweak it for gridding, then restore it at the end
savemod = ftr.model
@@ -301,6 +306,17 @@ def grid_chisq(
wftr = WrappedFitter(ftr, **fitargs)
+ if isinstance(executor, concurrent.futures.Executor):
+ # the executor has already been created
+ executor = executor
+ elif executor is None and (ncpu is None or ncpu > 1):
+ # make the default type of Executor
+ if ncpu is None:
+ ncpu = multiprocessing.cpu_count()
+ executor = concurrent.futures.ProcessPoolExecutor(
+ max_workers=ncpu, initializer=set_log, initargs=(log,)
+ )
+
# All other unfrozen parameters will be fitted for at each grid point
out = np.meshgrid(*parvalues)
chi2 = np.zeros(out[0].shape)
diff --git a/src/pint/logging.py b/src/pint/logging.py
index 4eb40f286..0f4f67021 100644
--- a/src/pint/logging.py
+++ b/src/pint/logging.py
@@ -224,7 +224,7 @@ def filter(self, record):
if not self.onlyonce[m]:
self.onlyonce[m] = [message_to_save]
return True
- elif not (message_to_save in self.onlyonce[m]):
+ elif message_to_save not in self.onlyonce[m]:
self.onlyonce[m].append(message_to_save)
return True
return False
@@ -311,6 +311,7 @@ def setup(
filter=filter,
format=format,
colorize=usecolors,
+ enqueue=True,
)
# change default colors
for level in colors:
diff --git a/src/pint/mcmc_fitter.py b/src/pint/mcmc_fitter.py
index 0a1797a98..5b28a23dc 100644
--- a/src/pint/mcmc_fitter.py
+++ b/src/pint/mcmc_fitter.py
@@ -89,7 +89,7 @@ def set_priors_basic(ftr, priorerrfact=10.0):
"""
fkeys, fvals, ferrs = ftr.fitkeys, ftr.fitvals, ftr.fiterrs
for key, val, err in zip(fkeys, fvals, ferrs):
- if key == "SINI" or key == "E" or key == "ECC":
+ if key in ["SINI", "E", "ECC"]:
getattr(ftr.model, key).prior = Prior(uniform(0.0, 1.0))
elif key == "PX":
getattr(ftr.model, key).prior = Prior(uniform(0.0, 10.0))
@@ -99,7 +99,7 @@ def set_priors_basic(ftr, priorerrfact=10.0):
if err == 0 and not getattr(ftr.model, key).frozen:
ftr.priors_set = False
raise ValueError(
- "Parameter %s does not have uncertainty in par file" % key
+ f"Parameter {key} does not have uncertainty in par file"
)
getattr(ftr.model, key).prior = Prior(
norm(loc=float(val), scale=float(err * priorerrfact))
@@ -389,10 +389,7 @@ def prof_vs_weights(self, nbins=50, use_weights=False):
if nphotons <= 0:
hval = 0
else:
- if use_weights:
- hval = hmw(phss[good], weights=wgts)
- else:
- hval = hm(phss[good])
+ hval = hmw(phss[good], weights=wgts) if use_weights else hm(phss[good])
htests.append(hval)
if ii > 0 and ii % 2 == 0 and ii < 20:
r, c = ((ii - 2) / 2) / 3, ((ii - 2) / 2) % 3
@@ -412,22 +409,22 @@ def prof_vs_weights(self, nbins=50, use_weights=False):
if r == 2:
ax[r][c].set_xlabel("Phase")
f.suptitle(
- "%s: Minwgt / H-test / Approx # events" % self.model.PSR.value,
+ f"{self.model.PSR.value}: Minwgt / H-test / Approx # events",
fontweight="bold",
)
if use_weights:
- plt.savefig(self.model.PSR.value + "_profs_v_wgtcut.png")
+ plt.savefig(f"{self.model.PSR.value}_profs_v_wgtcut.png")
else:
- plt.savefig(self.model.PSR.value + "_profs_v_wgtcut_unweighted.png")
+ plt.savefig(f"{self.model.PSR.value}_profs_v_wgtcut_unweighted.png")
plt.close()
plt.plot(weights, htests, "k")
plt.xlabel("Min Weight")
plt.ylabel("H-test")
plt.title(self.model.PSR.value)
if use_weights:
- plt.savefig(self.model.PSR.value + "_htest_v_wgtcut.png")
+ plt.savefig(f"{self.model.PSR.value}_htest_v_wgtcut.png")
else:
- plt.savefig(self.model.PSR.value + "_htest_v_wgtcut_unweighted.png")
+ plt.savefig(f"{self.model.PSR.value}_htest_v_wgtcut_unweighted.png")
plt.close()
def plot_priors(self, chains, burnin, bins=100, scale=False):
@@ -642,33 +639,32 @@ def get_template_vals(self, phases, index):
)
if isinstance(self.templates[index], LCTemplate):
return self.templates[index](phases, use_cache=True)
- else:
- if self.xtemps[index] is None:
- ltemp = len(self.templates[index])
- self.xtemps[index] = np.arange(ltemp) * 1.0 / ltemp
- return np.interp(
- phases,
- self.xtemps[index],
- self.templates[index],
- right=self.templates[index][0],
- )
+ if self.xtemps[index] is None:
+ ltemp = len(self.templates[index])
+ self.xtemps[index] = np.arange(ltemp) * 1.0 / ltemp
+ return np.interp(
+ phases,
+ self.xtemps[index],
+ self.templates[index],
+ right=self.templates[index][0],
+ )
def get_weights(self, index=None):
- if not index is None:
+ if index is not None:
return self.weights[index]
- else:
- wgts = np.zeros(len(self.toas.table))
- curr = 0
- for i in range(len(self.toas_list)):
- ts = self.toas_list[i]
- nxt = curr + len(ts.table)
- print(curr, nxt, len(ts.table))
- if self.weights[i] is None:
- wgts[curr:nxt] = 1.0 * self.set_weights[i]
- else:
- wgts[curr:nxt] = self.weights[i] * self.set_weights[i]
- curr = nxt
- return wgts
+ wgts = np.zeros(len(self.toas.table))
+ curr = 0
+ for i in range(len(self.toas_list)):
+ ts = self.toas_list[i]
+ nxt = curr + len(ts.table)
+ print(curr, nxt, len(ts.table))
+ wgts[curr:nxt] = (
+ 1.0 * self.set_weights[i]
+ if self.weights[i] is None
+ else self.weights[i] * self.set_weights[i]
+ )
+ curr = nxt
+ return wgts
def lnlikelihood(self, fitter, theta):
"""Sum over the log-likelihood functions for each dataset, multiply by weights in the sum.
diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py
index e2815cb92..ca7900338 100644
--- a/src/pint/models/__init__.py
+++ b/src/pint/models/__init__.py
@@ -21,12 +21,13 @@
# Import all standard model components here
from pint.models.astrometry import AstrometryEcliptic, AstrometryEquatorial
from pint.models.binary_bt import BinaryBT
-from pint.models.binary_dd import BinaryDD
+from pint.models.binary_dd import BinaryDD, BinaryDDS, BinaryDDGR
from pint.models.binary_ddk import BinaryDDK
from pint.models.binary_ell1 import BinaryELL1, BinaryELL1H, BinaryELL1k
from pint.models.dispersion_model import DispersionDM, DispersionDMX
from pint.models.frequency_dependent import FD
from pint.models.glitch import Glitch
+from pint.models.phase_offset import PhaseOffset
from pint.models.piecewise import PiecewiseSpindown
from pint.models.ifunc import IFunc
from pint.models.jump import DelayJump, PhaseJump
diff --git a/src/pint/models/absolute_phase.py b/src/pint/models/absolute_phase.py
index 0829b863f..9d319e2eb 100644
--- a/src/pint/models/absolute_phase.py
+++ b/src/pint/models/absolute_phase.py
@@ -59,7 +59,7 @@ def validate(self):
raise MissingParameter(
"AbsPhase",
"TZRMJD",
- "TZRMJD is required " "to compute the absolute phase. ",
+ "TZRMJD is required to compute the absolute phase.",
)
if self.TZRSITE.value is None:
self.TZRSITE.value = "ssb"
@@ -81,16 +81,15 @@ def get_TZR_toa(self, toas):
"""
clkc_info = toas.clock_corr_info
# If we have cached the TZR TOA and all the TZR* and clock info has not changed, then don't rebuild it
- if self.tz_cache is not None:
- if (
- self.tz_clkc_info["include_bipm"] == clkc_info["include_bipm"]
- and self.tz_clkc_info["include_gps"] == clkc_info["include_gps"]
- and self.tz_planets == toas.planets
- and self.tz_ephem == toas.ephem
- and self.tz_hash
- == hash((self.TZRMJD.value, self.TZRSITE.value, self.TZRFRQ.value))
- ):
- return self.tz_cache
+ if self.tz_cache is not None and (
+ self.tz_clkc_info["include_bipm"] == clkc_info["include_bipm"]
+ and self.tz_clkc_info["include_gps"] == clkc_info["include_gps"]
+ and self.tz_planets == toas.planets
+ and self.tz_ephem == toas.ephem
+ and self.tz_hash
+ == hash((self.TZRMJD.value, self.TZRSITE.value, self.TZRFRQ.value))
+ ):
+ return self.tz_cache
# Otherwise we have to build the TOA and apply clock corrections
# NOTE: Using TZRMJD.quantity.jd[1,2] so that the time scale can be properly
# set to the TZRSITE default timescale (e.g. UTC for TopoObs and TDB for SSB)
@@ -115,6 +114,7 @@ def get_TZR_toa(self, toas):
include_gps=clkc_info["include_gps"],
ephem=toas.ephem,
planets=toas.planets,
+ tzr=True,
)
log.debug("Done with TZR_toa")
self.tz_cache = tz
diff --git a/src/pint/models/astrometry.py b/src/pint/models/astrometry.py
index b98f638aa..38b5ac7c6 100644
--- a/src/pint/models/astrometry.py
+++ b/src/pint/models/astrometry.py
@@ -113,10 +113,7 @@ def sun_angle(self, toas, heliocenter=True, also_distance=False):
r = (osv**2).sum(axis=1) ** 0.5
osv /= r[:, None]
cos = (osv * psr_vec).sum(axis=1)
- if also_distance:
- return np.arccos(cos), r
- else:
- return np.arccos(cos)
+ return (np.arccos(cos), r) if also_distance else np.arccos(cos)
def barycentric_radio_freq(self, toas):
raise NotImplementedError
@@ -150,17 +147,13 @@ def solar_system_geometric_delay(self, toas, acc_delay=None):
def get_d_delay_quantities(self, toas):
"""Calculate values needed for many d_delay_d_param functions"""
- # TODO: Move all these calculations in a separate class for elegance
- rd = dict()
-
# TODO: Should delay not have units of u.second?
delay = self._parent.delay(toas)
# TODO: tbl['tdbld'].quantity should have units of u.day
# NOTE: Do we need to include the delay here?
tbl = toas.table
- rd["epoch"] = tbl["tdbld"].quantity * u.day # - delay * u.second
-
+ rd = {"epoch": tbl["tdbld"].quantity * u.day}
# Distance from SSB to observatory, and from SSB to psr
ssb_obs = tbl["ssb_obs_pos"].quantity
ssb_psr = self.ssb_to_psb_xyz_ICRS(epoch=np.array(rd["epoch"]))
@@ -286,7 +279,7 @@ def __init__(self):
)
self.set_special_params(["RAJ", "DECJ", "PMRA", "PMDEC"])
for param in ["RAJ", "DECJ", "PMRA", "PMDEC"]:
- deriv_func_name = "d_delay_astrometry_d_" + param
+ deriv_func_name = f"d_delay_astrometry_d_{param}"
func = getattr(self, deriv_func_name)
self.register_deriv_funcs(func, param)
@@ -352,21 +345,19 @@ def get_psr_coords(self, epoch=None):
obstime=self.POSEPOCH.quantity,
frame=coords.ICRS,
)
- else:
- if isinstance(epoch, Time):
- newepoch = epoch
- else:
- newepoch = Time(epoch, scale="tdb", format="mjd")
- position_now = add_dummy_distance(self.get_psr_coords())
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", ErfaWarning)
- # for the most part the dummy distance should remove any potential erfa warnings
- # but for some very large proper motions that does not quite work
- # so we catch the warnings
- position_then = position_now.apply_space_motion(new_obstime=newepoch)
- position_then = remove_dummy_distance(position_then)
-
- return position_then
+ newepoch = (
+ epoch if isinstance(epoch, Time) else Time(epoch, scale="tdb", format="mjd")
+ )
+ position_now = add_dummy_distance(self.get_psr_coords())
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", ErfaWarning)
+ # for the most part the dummy distance should remove any potential erfa warnings
+ # but for some very large proper motions that does not quite work
+ # so we catch the warnings
+ position_then = position_now.apply_space_motion(new_obstime=newepoch)
+ position_then = remove_dummy_distance(position_then)
+
+ return position_then
def coords_as_ICRS(self, epoch=None):
"""Return the pulsar's ICRS coordinates as an astropy coordinate object."""
@@ -392,13 +383,12 @@ def coords_as_GAL(self, epoch=None):
return pos_icrs.transform_to(coords.Galactic)
def get_params_as_ICRS(self):
- result = {
+ return {
"RAJ": self.RAJ.quantity,
"DECJ": self.DECJ.quantity,
"PMRA": self.PMRA.quantity,
"PMDEC": self.PMDEC.quantity,
}
- return result
def d_delay_astrometry_d_RAJ(self, toas, param="", acc_delay=None):
"""Calculate the derivative wrt RAJ
@@ -650,7 +640,7 @@ def __init__(self):
self.set_special_params(["ELONG", "ELAT", "PMELONG", "PMELAT"])
for param in ["ELAT", "ELONG", "PMELAT", "PMELONG"]:
- deriv_func_name = "d_delay_astrometry_d_" + param
+ deriv_func_name = f"d_delay_astrometry_d_{param}"
func = getattr(self, deriv_func_name)
self.register_deriv_funcs(func, param)
@@ -702,11 +692,10 @@ def get_psr_coords(self, epoch=None):
"""
try:
obliquity = OBL[self.ECL.value]
- except KeyError:
+ except KeyError as e:
raise ValueError(
- "No obliquity " + str(self.ECL.value) + " provided. "
- "Check your pint/datafile/ecliptic.dat file."
- )
+ f"No obliquity {str(self.ECL.value)} provided. Check your pint/datafile/ecliptic.dat file."
+ ) from e
if epoch is None or (self.PMELONG.value == 0.0 and self.PMELAT.value == 0.0):
# Compute only once
return coords.SkyCoord(
@@ -718,17 +707,14 @@ def get_psr_coords(self, epoch=None):
obstime=self.POSEPOCH.quantity,
frame=PulsarEcliptic,
)
- else:
# Compute for each time because there is proper motion
- if isinstance(epoch, Time):
- newepoch = epoch
- else:
- newepoch = Time(epoch, scale="tdb", format="mjd")
- position_now = add_dummy_distance(self.get_psr_coords())
- position_then = remove_dummy_distance(
- position_now.apply_space_motion(new_obstime=newepoch)
- )
- return position_then
+ newepoch = (
+ epoch if isinstance(epoch, Time) else Time(epoch, scale="tdb", format="mjd")
+ )
+ position_now = add_dummy_distance(self.get_psr_coords())
+ return remove_dummy_distance(
+ position_now.apply_space_motion(new_obstime=newepoch)
+ )
def coords_as_ICRS(self, epoch=None):
"""Return the pulsar's ICRS coordinates as an astropy coordinate object."""
@@ -755,15 +741,17 @@ def coords_as_ECL(self, epoch=None, ecl=None):
def get_d_delay_quantities_ecliptical(self, toas):
"""Calculate values needed for many d_delay_d_param functions."""
# TODO: Move all these calculations in a separate class for elegance
- rd = dict()
+
# From the earth_ra dec to earth_elong and elat
try:
obliquity = OBL[self.ECL.value]
- except KeyError:
+ except KeyError as e:
raise ValueError(
- "No obliquity " + self.ECL.value + " provided. "
- "Check your pint/datafile/ecliptic.dat file."
- )
+ (
+ f"No obliquity {self.ECL.value}" + " provided. "
+ "Check your pint/datafile/ecliptic.dat file."
+ )
+ ) from e
rd = self.get_d_delay_quantities(toas)
coords_icrs = coords.ICRS(ra=rd["earth_ra"], dec=rd["earth_dec"])
@@ -774,14 +762,14 @@ def get_d_delay_quantities_ecliptical(self, toas):
return rd
def get_params_as_ICRS(self):
- result = dict()
pv_ECL = self.get_psr_coords()
pv_ICRS = pv_ECL.transform_to(coords.ICRS)
- result["RAJ"] = pv_ICRS.ra.to(u.hourangle)
- result["DECJ"] = pv_ICRS.dec
- result["PMRA"] = pv_ICRS.pm_ra_cosdec
- result["PMDEC"] = pv_ICRS.pm_dec
- return result
+ return {
+ "RAJ": pv_ICRS.ra.to(u.hourangle),
+ "DECJ": pv_ICRS.dec,
+ "PMRA": pv_ICRS.pm_ra_cosdec,
+ "PMDEC": pv_ICRS.pm_dec,
+ }
def d_delay_astrometry_d_ELONG(self, toas, param="", acc_delay=None):
"""Calculate the derivative wrt RAJ.
@@ -793,7 +781,7 @@ def d_delay_astrometry_d_ELONG(self, toas, param="", acc_delay=None):
ae = Earth right ascension
dp = pulsar declination
aa = pulsar right ascension
- r = distance from SSB to Earh
+ r = distance from SSB to Earth
c = speed of light
delay = r*[cos(de)*cos(dp)*cos(ae-aa)+sin(de)*sin(dp)]/c
@@ -802,7 +790,7 @@ def d_delay_astrometry_d_ELONG(self, toas, param="", acc_delay=None):
elonge = Earth elong
elatp = pulsar elat
elongp = pulsar elong
- r = distance from SSB to Earh
+ r = distance from SSB to Earth
c = speed of light
delay = r*[cos(elate)*cos(elatp)*cos(elonge-elongp)+sin(elate)*sin(elatp)]/c
@@ -969,8 +957,12 @@ def as_ECL(self, epoch=None, ecl="IERS2010"):
lat=self.ELAT.quantity,
obliquity=OBL[self.ECL.value],
obstime=self.POSEPOCH.quantity,
- pm_lon_coslat=self.PMELONG.uncertainty,
- pm_lat=self.PMELAT.uncertainty,
+ pm_lon_coslat=self.PMELONG.uncertainty
+ if self.PMELONG.uncertainty is not None
+ else 0 * self.PMELONG.units,
+ pm_lat=self.PMELAT.uncertainty
+ if self.PMELAT.uncertainty is not None
+ else 0 * self.PMELAT.units,
frame=PulsarEcliptic,
)
c_ECL = c.transform_to(PulsarEcliptic(ecl=ecl))
@@ -1019,8 +1011,12 @@ def as_ICRS(self, epoch=None):
lat=self.ELAT.quantity,
obliquity=OBL[self.ECL.value],
obstime=self.POSEPOCH.quantity,
- pm_lon_coslat=self.ELONG.uncertainty * np.cos(self.ELAT.quantity) / dt,
- pm_lat=self.ELAT.uncertainty / dt,
+ pm_lon_coslat=self.ELONG.uncertainty * np.cos(self.ELAT.quantity) / dt
+ if self.ELONG.uncertainty is not None
+ else 0 * self.ELONG.units / dt,
+ pm_lat=self.ELAT.uncertainty / dt
+ if self.ELAT.uncertainty is not None
+ else 0 * self.ELAT.units / dt,
frame=PulsarEcliptic,
)
c_ICRS = c.transform_to(coords.ICRS)
@@ -1033,8 +1029,12 @@ def as_ICRS(self, epoch=None):
lat=self.ELAT.quantity,
obliquity=OBL[self.ECL.value],
obstime=self.POSEPOCH.quantity,
- pm_lon_coslat=self.PMELONG.uncertainty,
- pm_lat=self.PMELAT.uncertainty,
+ pm_lon_coslat=self.PMELONG.uncertainty
+ if self.PMELONG.uncertainty is not None
+ else 0 * self.PMELONG.units,
+ pm_lat=self.PMELAT.uncertainty
+ if self.PMELAT.uncertainty is not None
+ else 0 * self.PMELAT.units,
frame=PulsarEcliptic,
)
c_ICRS = c.transform_to(coords.ICRS)
diff --git a/src/pint/models/binary_dd.py b/src/pint/models/binary_dd.py
index b9f86d115..ad855ffe8 100644
--- a/src/pint/models/binary_dd.py
+++ b/src/pint/models/binary_dd.py
@@ -1,7 +1,23 @@
"""Damour and Deruelle binary model."""
-from pint.models.parameter import floatParameter
+import numpy as np
+from astropy import units as u, constants as c
+
+from pint.models.parameter import floatParameter, funcParameter
from pint.models.pulsar_binary import PulsarBinary
from pint.models.stand_alone_psr_binaries.DD_model import DDmodel
+from pint.models.stand_alone_psr_binaries.DDS_model import DDSmodel
+from pint.models.stand_alone_psr_binaries.DDGR_model import DDGRmodel
+import pint.derived_quantities
+
+
+# these would be doable with lambda functions
+# but then the instances would not pickle
+def _sini_from_shapmax(SHAPMAX):
+ return 1 - np.exp(-SHAPMAX)
+
+
+def _mp_from_mtot(MTOT, M2):
+ return MTOT - M2
class BinaryDD(PulsarBinary):
@@ -70,6 +86,7 @@ def __init__(
self.add_param(
floatParameter(
name="DTH",
+ aliases=["DTHETA"],
value=0.0,
units="",
description="Relativistic deformation of the orbit",
@@ -83,16 +100,257 @@ def validate(self):
# If any *DOT is set, we need T0
for p in ("PBDOT", "OMDOT", "EDOT", "A1DOT"):
if hasattr(self, p) and getattr(self, p).value is None:
- getattr(self, p).set("0")
+ getattr(self, p).value = 0.0
getattr(self, p).frozen = True
- if self.GAMMA.value is None:
- self.GAMMA.set("0")
+ if hasattr(self, "GAMMA") and self.GAMMA.value is None:
+ self.GAMMA.value = 0.0
self.GAMMA.frozen = True
# If eccentricity is zero, freeze some parameters to 0
# OM = 0 -> T0 = TASC
if self.ECC.value == 0 or self.ECC.value is None:
for p in ("ECC", "OM", "OMDOT", "EDOT"):
- getattr(self, p).set("0")
- getattr(self, p).frozen = True
+ if hasattr(self, p):
+ getattr(self, p).value = 0.0
+ getattr(self, p).frozen = True
+
+
+class BinaryDDS(BinaryDD):
+ """Damour and Deruelle model with alternate Shapiro delay parameterization.
+
+ This extends the :class:`pint.models.binary_dd.BinaryDD` model with
+ :math:`SHAPMAX = -\log(1-s)` instead of just :math:`s=\sin i`, which behaves better
+ for :math:`\sin i` near 1. It does not (yet) implement the higher-order delays and lensing correction.
+
+ The actual calculations for this are done in
+ :class:`pint.models.stand_alone_psr_binaries.DDS_model.DDSmodel`.
+
+ It supports all the parameters defined in :class:`pint.models.pulsar_binary.PulsarBinary`
+ and :class:`pint.models.binary_dd.BinaryDD` plus:
+
+ SHAPMAX
+ :math:`-\log(1-\sin i)`
+
+ It also converts:
+
+ SINI
+ into a read-only parameter
+
+ Parameters supported:
+
+ .. paramtable::
+ :class: pint.models.binary_dd.BinaryDDS
+
+ References
+ ----------
+ - Kramer et al. (2006), Science, 314, 97 [ksm+2006]_
+ - Rafikov and Lai (2006), PRD, 73, 063003 [rl2006]_
+
+ .. [ksm+2006] https://ui.adsabs.harvard.edu/abs/2006Sci...314...97K/abstract
+ .. [rl2006] https://ui.adsabs.harvard.edu/abs/2006PhRvD..73f3003R/abstract
+
+ """
+
+ register = True
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ self.binary_model_name = "DDS"
+ self.binary_model_class = DDSmodel
+
+ self.add_param(
+ floatParameter(
+ name="SHAPMAX", value=0.0, description="Function of inclination angle"
+ )
+ )
+ self.remove_param("SINI")
+ self.add_param(
+ funcParameter(
+ name="SINI",
+ units="",
+ description="Sine of inclination angle",
+ params=("SHAPMAX",),
+ func=_sini_from_shapmax,
+ )
+ )
+
+ def validate(self):
+ """Validate parameters."""
+ super().validate()
+ if (
+ hasattr(self, "SHAPMAX")
+ and self.SHAPMAX.value is not None
+ and not self.SHAPMAX.value > -np.log(2)
+ ):
+ raise ValueError(f"SHAPMAX must be > -log(2) ({self.SHAPMAX.quantity})")
+
+
+class BinaryDDGR(BinaryDD):
+ """Damour and Deruelle model assuming GR to be correct
+
+ It supports all the parameters defined in :class:`pint.models.pulsar_binary.PulsarBinary`
+ and :class:`pint.models.binary_dd.BinaryDD` plus:
+
+ MTOT
+ Total mass
+ XPBDOT
+ Excess PBDOT beyond what GR predicts
+ XOMDOT
+ Excess OMDOT beyond what GR predicts
+
+ It also reads but converts:
+
+ SINI
+ PBDOT
+ OMDOT
+ GAMMA
+ DR
+ DTH
+
+ into read-only parameters
+
+ The actual calculations for this are done in
+ :class:`pint.models.stand_alone_psr_binaries.DDGR_model.DDGRmodel`.
+
+ Parameters supported:
+
+ .. paramtable::
+ :class: pint.models.binary_dd.BinaryDDGR
+
+ References
+ ----------
+ - Taylor and Weisberg (1989), ApJ, 345, 434 [tw89]_
+
+ .. [tw89] https://ui.adsabs.harvard.edu/abs/1989ApJ...345..434T/abstract
+ """
+
+ register = True
+
+ def __init__(
+ self,
+ ):
+ super().__init__()
+ self.binary_model_name = "DDGR"
+ self.binary_model_class = DDGRmodel
+
+ self.add_param(
+ floatParameter(
+ name="MTOT",
+ units=u.M_sun,
+ description="Total system mass in units of Solar mass",
+ )
+ )
+ self.add_param(
+ floatParameter(
+ name="XOMDOT",
+ units="deg/year",
+ description="Excess longitude of periastron advance compared to GR",
+ long_double=True,
+ )
+ )
+ self.add_param(
+ floatParameter(
+ name="XPBDOT",
+ units=u.day / u.day,
+ description="Excess Orbital period derivative respect to time compared to GR",
+ unit_scale=True,
+ scale_factor=1e-12,
+ scale_threshold=1e-7,
+ )
+ )
+ for p in ["OMDOT", "PBDOT", "GAMMA", "SINI", "DR", "DTH"]:
+ self.remove_param(p)
+
+ self.add_param(
+ funcParameter(
+ name="MP",
+ units=u.M_sun,
+ description="Pulsar mass",
+ params=("MTOT", "M2"),
+ func=_mp_from_mtot,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="OMDOT",
+ units="deg/year",
+ description="Rate of advance of periastron",
+ long_double=True,
+ params=("MP", "M2", "PB", "ECC"),
+ func=pint.derived_quantities.omdot,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="SINI",
+ units="",
+ description="Sine of inclination angle",
+ params=("MP", "M2", "PB", "A1"),
+ func=pint.derived_quantities.sini,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="PBDOT",
+ units=u.day / u.day,
+ description="Orbital period derivative respect to time",
+ unit_scale=True,
+ scale_factor=1e-12,
+ scale_threshold=1e-7,
+ params=("MP", "M2", "PB", "ECC"),
+ func=pint.derived_quantities.pbdot,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="GAMMA",
+ units="second",
+ description="Time dilation & gravitational redshift",
+ params=("MP", "M2", "PB", "ECC"),
+ func=pint.derived_quantities.gamma,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="DR",
+ units="",
+ description="Relativistic deformation of the orbit",
+ params=("MP", "M2", "PB"),
+ func=pint.derived_quantities.dr,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="DTH",
+ aliases=["DTHETA"],
+ units="",
+ description="Relativistic deformation of the orbit",
+ params=("MP", "M2", "PB"),
+ func=pint.derived_quantities.dth,
+ )
+ )
+
+ def setup(self):
+ """Parameter setup."""
+ super().setup()
+
+ def validate(self):
+ """Validate parameters."""
+ super().validate()
+ aR = (c.G * self.MTOT.quantity * self.PB.quantity**2 / 4 / np.pi**2) ** (
+ 1.0 / 3
+ )
+ sini = (
+ self.A1.quantity * self.MTOT.quantity / aR / self.M2.quantity
+ ).decompose()
+ if sini > 1:
+ raise ValueError(
+ f"Inferred SINI must be <= 1 for DDGR model (MTOT={self.MTOT.quantity}, PB={self.PB.quantity}, A1={self.A1.quantity}, M2={self.M2.quantity} imply SINI={sini})"
+ )
+
+ def update_binary_object(self, toas, acc_delay=None):
+ super().update_binary_object(toas, acc_delay)
+ self.binary_instance._updatePK()
diff --git a/src/pint/models/binary_ddk.py b/src/pint/models/binary_ddk.py
index 645befda5..c734d83e3 100644
--- a/src/pint/models/binary_ddk.py
+++ b/src/pint/models/binary_ddk.py
@@ -200,14 +200,13 @@ def validate(self):
"No valid AstrometryEcliptic or AstrometryEquatorial component found"
)
- if hasattr(self._parent, "PX"):
- if self._parent.PX.value <= 0.0 or self._parent.PX.value is None:
- raise TimingModelError("DDK model needs a valid `PX` value.")
- else:
+ if not hasattr(self._parent, "PX"):
raise MissingParameter(
"Binary_DDK", "PX", "DDK model needs PX from" "Astrometry."
)
+ if self._parent.PX.value <= 0.0 or self._parent.PX.value is None:
+ raise TimingModelError("DDK model needs a valid `PX` value.")
if "A1DOT" in self.params and self.A1DOT.value != 0:
warnings.warn("Using A1DOT with a DDK model is not advised.")
diff --git a/src/pint/models/binary_ell1.py b/src/pint/models/binary_ell1.py
index 1bb1cf093..c0051a446 100644
--- a/src/pint/models/binary_ell1.py
+++ b/src/pint/models/binary_ell1.py
@@ -18,6 +18,7 @@
from pint.models.stand_alone_psr_binaries.ELL1k_model import ELL1kmodel
from pint.models.timing_model import MissingParameter
from pint.utils import taylor_horner_deriv
+from pint import Tsun
def _eps_to_e(eps1, eps2):
@@ -28,7 +29,21 @@ def _eps_to_om(eps1, eps2):
OM = np.arctan2(eps1, eps2)
if OM < 0:
OM += 360 * u.deg
- return OM
+ return OM.to(u.deg)
+
+
+def _epsdot_to_edot(eps1, eps2, eps1dot, eps2dot):
+ # Eqn. A14,A15 in Lange et al. inverted
+ ecc = np.sqrt(eps1**2 + eps2**2)
+ return (eps1dot * eps1 + eps2dot * eps2) / ecc
+
+
+def _epsdot_to_omdot(eps1, eps2, eps1dot, eps2dot):
+ # Eqn. A14,A15 in Lange et al. inverted
+ ecc = np.sqrt(eps1**2 + eps2**2)
+ return ((eps1dot * eps2 - eps2dot * eps1) / ecc**2).to(
+ u.deg / u.yr, equivalencies=u.dimensionless_angles()
+ )
def _tasc_to_T0(TASC, PB, eps1, eps2):
@@ -45,7 +60,7 @@ class BinaryELL1(PulsarBinary):
This binary model uses a rectangular representation for the eccentricity of an orbit,
resolving complexities that arise with periastron-based parameters in nearly-circular
- orbits. It also makes certain approximations that are invalid when the eccentricity
+ orbits. It also makes certain approximations (up to O(e^3)) that are invalid when the eccentricity
is "large"; what qualifies as "large" depends on your data quality. A formula exists
to determine when the approximations this model makes are sufficiently accurate.
@@ -63,8 +78,26 @@ class BinaryELL1(PulsarBinary):
References
----------
- Lange et al. (2001), MNRAS, 326 (1), 274–282 [1]_
+ - Zhu et al. (2019), MNRAS, 482 (3), 3249-3260 [2]_
+ - Fiore et al. (2023), arXiv:2305.13624 [astro-ph.HE] [3]_
+
+ .. [1] https://ui.adsabs.harvard.edu/abs/2019MNRAS.482.3249Z/abstract
+ .. [2] https://ui.adsabs.harvard.edu/abs/2001MNRAS.326..274L/abstract
+ .. [3] https://arxiv.org/abs/2305.13624
+
+ Notes
+ -----
+ This includes o(e^2) expression for Roemer delay from Norbert Wex and Weiwei Zhu
+ This is equation (1) of Zhu et al (2019) but with a corrected typo:
+ In the first line of that equation, ex->e1 and ey->e2
+ In the other lines, ex->e2 and ey->e1
+ See Email from NW and WZ to David Nice on 2019-Aug-08
+ The dre expression comes from NW and WZ; the derivatives
+ were calculated by hand for PINT
+
+ Also includes o(e^3) expression from equation (4) of Fiore et al. (2023)
+ (derivatives also calculated by hand)
- .. [1] https://ui.adsabs.harvard.edu/abs/2001MNRAS.326..274L/abstract
"""
register = True
@@ -139,6 +172,28 @@ def __init__(self):
func=_eps_to_om,
)
)
+ self.add_param(
+ funcParameter(
+ name="EDOT",
+ units="1/s",
+ description="Eccentricity derivative respect to time",
+ unit_scale=True,
+ scale_factor=1e-12,
+ scale_threshold=1e-7,
+ params=("EPS1", "EPS2", "EPS1DOT", "EPS2DOT"),
+ func=_epsdot_to_edot,
+ )
+ )
+ self.add_param(
+ funcParameter(
+ name="OMDOT",
+ units="deg/year",
+ description="Rate of advance of periastron",
+ long_double=True,
+ params=("EPS1", "EPS2", "EPS1DOT", "EPS2DOT"),
+ func=_epsdot_to_omdot,
+ )
+ )
# don't implement T0 yet since that is a MJDparameter at base
# and our funcParameters don't support that yet
# self.add_param(
diff --git a/src/pint/models/dispersion_model.py b/src/pint/models/dispersion_model.py
index 40dd831a5..40acfdaec 100644
--- a/src/pint/models/dispersion_model.py
+++ b/src/pint/models/dispersion_model.py
@@ -74,11 +74,7 @@ def dm_value(self, toas):
------
DM values at given TOAs in the unit of DM.
"""
- if isinstance(toas, Table):
- toas_table = toas
- else:
- toas_table = toas.table
-
+ toas_table = toas if isinstance(toas, Table) else toas.table
dm = np.zeros(len(toas_table)) * self._parent.DM.units
for dm_f in self.dm_value_funcs:
@@ -127,15 +123,10 @@ def register_dm_deriv_funcs(self, func, param):
if pn not in list(self.dm_deriv_funcs.keys()):
self.dm_deriv_funcs[pn] = [func]
+ elif func in self.dm_deriv_funcs[pn]:
+ return
else:
- # TODO:
- # Running setup() multiple times can lead to adding derivative
- # function multiple times. This prevent it from happening now. But
- # in the future, we should think a better way to do so.
- if func in self.dm_deriv_funcs[pn]:
- return
- else:
- self.dm_deriv_funcs[pn] += [func]
+ self.dm_deriv_funcs[pn] += [func]
class DispersionDM(Dispersion):
@@ -201,12 +192,12 @@ def validate(self):
if self.DMEPOCH.value is None:
# Copy PEPOCH (PEPOCH must be set!)
self.DMEPOCH.value = self._parent.PEPOCH.value
- if self.DMEPOCH.value is None:
- raise MissingParameter(
- "Dispersion",
- "DMEPOCH",
- "DMEPOCH or PEPOCH is required if DM1 or higher are set",
- )
+ if self.DMEPOCH.value is None:
+ raise MissingParameter(
+ "Dispersion",
+ "DMEPOCH",
+ "DMEPOCH or PEPOCH is required if DM1 or higher are set",
+ )
def DM_dervative_unit(self, n):
return "pc cm^-3/yr^%d" % n if n else "pc cm^-3"
@@ -242,13 +233,9 @@ def constant_dispersion_delay(self, toas, acc_delay=None):
return self.dispersion_type_delay(toas)
def print_par(self, format="pint"):
- # TODO we need to have a better design for print out the parameters in
- # an inheritance class.
- result = ""
prefix_dm = list(self.get_prefix_mapping_component("DM").values())
dms = ["DM"] + prefix_dm
- for dm in dms:
- result += getattr(self, dm).as_parfile_line(format=format)
+ result = "".join(getattr(self, dm).as_parfile_line(format=format) for dm in dms)
if hasattr(self, "components"):
all_params = self.components["DispersionDM"].params
else:
@@ -280,11 +267,7 @@ def d_dm_d_DMs(
DMEPOCH = self.DMEPOCH.value
dt = (toas["tdbld"] - DMEPOCH) * u.day
dt_value = (dt.to(u.yr)).value
- d_dm_d_dm_param = taylor_horner(dt_value, dm_terms) * (
- self.DM.units / par.units
- )
-
- return d_dm_d_dm_param
+ return taylor_horner(dt_value, dm_terms) * (self.DM.units / par.units)
def change_dmepoch(self, new_epoch):
"""Change DMEPOCH to a new value and update DM accordingly.
@@ -312,7 +295,7 @@ def change_dmepoch(self, new_epoch):
dt = (new_epoch.tdb.mjd_long - dmepoch_ld) * u.day
for n in range(len(dmterms) - 1):
- cur_deriv = self.DM if n == 0 else getattr(self, "DM{}".format(n))
+ cur_deriv = self.DM if n == 0 else getattr(self, f"DM{n}")
cur_deriv.value = taylor_horner_deriv(
dt.to(u.yr), dmterms, deriv_order=n + 1
)
@@ -391,8 +374,7 @@ def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True):
if int(index) in self.get_prefix_mapping_component("DMX_"):
raise ValueError(
- "Index '%s' is already in use in this model. Please choose another."
- % index
+ f"Index '{index}' is already in use in this model. Please choose another."
)
if isinstance(dmx, u.quantity.Quantity):
@@ -407,7 +389,7 @@ def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True):
mjd_end = mjd_end.value
self.add_param(
prefixParameter(
- name="DMX_" + i,
+ name=f"DMX_{i}",
units="pc cm^-3",
value=dmx,
description="Dispersion measure variation",
@@ -417,7 +399,7 @@ def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True):
)
self.add_param(
prefixParameter(
- name="DMXR1_" + i,
+ name=f"DMXR1_{i}",
units="MJD",
description="Beginning of DMX interval",
parameter_type="MJD",
@@ -427,7 +409,7 @@ def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True):
)
self.add_param(
prefixParameter(
- name="DMXR2_" + i,
+ name=f"DMXR2_{i}",
units="MJD",
description="End of DMX interval",
parameter_type="MJD",
@@ -508,8 +490,7 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru
raise ValueError("Only one MJD bound is set.")
if int(index) in dct:
raise ValueError(
- "Index '%s' is already in use in this model. Please choose another."
- % index
+ f"Index '{index}' is already in use in this model. Please choose another."
)
if isinstance(dmx, u.quantity.Quantity):
dmx = dmx.to_value(u.pc / u.cm**3)
@@ -524,7 +505,7 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru
log.trace(f"Adding DMX_{i} from MJD {mjd_start} to MJD {mjd_end}")
self.add_param(
prefixParameter(
- name="DMX_" + i,
+ name=f"DMX_{i}",
units="pc cm^-3",
value=dmx,
description="Dispersion measure variation",
@@ -534,7 +515,7 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru
)
self.add_param(
prefixParameter(
- name="DMXR1_" + i,
+ name=f"DMXR1_{i}",
units="MJD",
description="Beginning of DMX interval",
parameter_type="MJD",
@@ -544,7 +525,7 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru
)
self.add_param(
prefixParameter(
- name="DMXR2_" + i,
+ name=f"DMXR2_{i}",
units="MJD",
description="End of DMX interval",
parameter_type="MJD",
diff --git a/src/pint/models/frequency_dependent.py b/src/pint/models/frequency_dependent.py
index 0dcb774ca..c78359cd0 100644
--- a/src/pint/models/frequency_dependent.py
+++ b/src/pint/models/frequency_dependent.py
@@ -59,10 +59,9 @@ def setup(self):
def validate(self):
super().validate()
- FD_terms = list(self.get_prefix_mapping_component("FD").keys())
- FD_terms.sort()
+ FD_terms = sorted(self.get_prefix_mapping_component("FD").keys())
FD_in_order = list(range(1, max(FD_terms) + 1))
- if not FD_terms == FD_in_order:
+ if FD_terms != FD_in_order:
diff = list(set(FD_in_order) - set(FD_terms))
raise MissingParameter("FD", "FD%d" % diff[0])
diff --git a/src/pint/models/glitch.py b/src/pint/models/glitch.py
index 4676c78a3..d89eec6b8 100644
--- a/src/pint/models/glitch.py
+++ b/src/pint/models/glitch.py
@@ -137,11 +137,11 @@ def setup(self):
for idx in set(self.glitch_indices):
for param in self.glitch_prop:
if not hasattr(self, param + "%d" % idx):
- param0 = getattr(self, param + "1")
+ param0 = getattr(self, f"{param}1")
self.add_param(param0.new_param(idx))
getattr(self, param + "%d" % idx).value = 0.0
self.register_deriv_funcs(
- getattr(self, "d_phase_d_" + param[0:-1]), param + "%d" % idx
+ getattr(self, f"d_phase_d_{param[:-1]}"), param + "%d" % idx
)
def validate(self):
@@ -152,14 +152,15 @@ def validate(self):
msg = "Glitch Epoch is needed for Glitch %d." % idx
raise MissingParameter("Glitch", "GLEP_%d" % idx, msg)
else: # Check to see if both the epoch and phase are to be fit
- if hasattr(self, "GLPH_%d" % idx):
- if (not getattr(self, "GLEP_%d" % idx).frozen) and (
- not getattr(self, "GLPH_%d" % idx).frozen
- ):
- raise ValueError(
- "Both the glitch epoch and phase cannot be fit for Glitch %d."
- % idx
- )
+ if (
+ hasattr(self, "GLPH_%d" % idx)
+ and (not getattr(self, "GLEP_%d" % idx).frozen)
+ and (not getattr(self, "GLPH_%d" % idx).frozen)
+ ):
+ raise ValueError(
+ "Both the glitch epoch and phase cannot be fit for Glitch %d."
+ % idx
+ )
# Check the Decay Term.
glf0dparams = [x for x in self.params if x.startswith("GLF0D_")]
@@ -226,7 +227,7 @@ def deriv_prep(self, toas, param, delay):
"""Get the things we need for any of the derivative calcs"""
tbl = toas.table
p, ids, idv = split_prefixed_name(param)
- eph = getattr(self, "GLEP_" + ids).value
+ eph = getattr(self, f"GLEP_{ids}").value
dt = (tbl["tdbld"] - eph) * u.day - delay
dt = dt.to(u.second)
affected = np.where(dt > 0.0)[0]
@@ -237,7 +238,7 @@ def d_phase_d_GLPH(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLPH_":
raise ValueError(
- "Can not calculate d_phase_d_GLPH with respect to %s." % param
+ f"Can not calculate d_phase_d_GLPH with respect to {param}."
)
par_GLPH = getattr(self, param)
dpdGLPH = np.zeros(len(tbl), dtype=np.longdouble) / par_GLPH.units
@@ -249,7 +250,7 @@ def d_phase_d_GLF0(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLF0_":
raise ValueError(
- "Can not calculate d_phase_d_GLF0 with respect to %s." % param
+ f"Can not calculate d_phase_d_GLF0 with respect to {param}."
)
par_GLF0 = getattr(self, param)
dpdGLF0 = np.zeros(len(tbl), dtype=np.longdouble) / par_GLF0.units
@@ -261,7 +262,7 @@ def d_phase_d_GLF1(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLF1_":
raise ValueError(
- "Can not calculate d_phase_d_GLF1 with respect to %s." % param
+ f"Can not calculate d_phase_d_GLF1 with respect to {param}."
)
par_GLF1 = getattr(self, param)
dpdGLF1 = np.zeros(len(tbl), dtype=np.longdouble) / par_GLF1.units
@@ -273,7 +274,7 @@ def d_phase_d_GLF2(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLF2_":
raise ValueError(
- "Can not calculate d_phase_d_GLF2 with respect to %s." % param
+ f"Can not calculate d_phase_d_GLF2 with respect to {param}."
)
par_GLF2 = getattr(self, param)
dpdGLF2 = np.zeros(len(tbl), dtype=np.longdouble) / par_GLF2.units
@@ -287,7 +288,7 @@ def d_phase_d_GLF0D(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLF0D_":
raise ValueError(
- "Can not calculate d_phase_d_GLF0D with respect to %s." % param
+ f"Can not calculate d_phase_d_GLF0D with respect to {param}."
)
par_GLF0D = getattr(self, param)
tau = getattr(self, "GLTD_%d" % idv).quantity
@@ -300,12 +301,12 @@ def d_phase_d_GLTD(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLTD_":
raise ValueError(
- "Can not calculate d_phase_d_GLTD with respect to %s." % param
+ f"Can not calculate d_phase_d_GLTD with respect to {param}."
)
par_GLTD = getattr(self, param)
if par_GLTD.value == 0.0:
return np.zeros(len(tbl), dtype=np.longdouble) / par_GLTD.units
- glf0d = getattr(self, "GLF0D_" + ids).quantity
+ glf0d = getattr(self, f"GLF0D_{ids}").quantity
tau = par_GLTD.quantity
dpdGLTD = np.zeros(len(tbl), dtype=np.longdouble) / par_GLTD.units
dpdGLTD[affected] += glf0d * (
@@ -318,14 +319,14 @@ def d_phase_d_GLEP(self, toas, param, delay):
tbl, p, ids, idv, dt, affected = self.deriv_prep(toas, param, delay)
if p != "GLEP_":
raise ValueError(
- "Can not calculate d_phase_d_GLEP with respect to %s." % param
+ f"Can not calculate d_phase_d_GLEP with respect to {param}."
)
par_GLEP = getattr(self, param)
- glf0 = getattr(self, "GLF0_" + ids).quantity
- glf1 = getattr(self, "GLF1_" + ids).quantity
- glf2 = getattr(self, "GLF2_" + ids).quantity
- glf0d = getattr(self, "GLF0D_" + ids).quantity
- tau = getattr(self, "GLTD_" + ids).quantity
+ glf0 = getattr(self, f"GLF0_{ids}").quantity
+ glf1 = getattr(self, f"GLF1_{ids}").quantity
+ glf2 = getattr(self, f"GLF2_{ids}").quantity
+ glf0d = getattr(self, f"GLF0D_{ids}").quantity
+ tau = getattr(self, f"GLTD_{ids}").quantity
dpdGLEP = np.zeros(len(tbl), dtype=np.longdouble) / par_GLEP.units
dpdGLEP[affected] += (
-glf0 + -glf1 * dt[affected] + -0.5 * glf2 * dt[affected] ** 2
diff --git a/src/pint/models/ifunc.py b/src/pint/models/ifunc.py
index 70e79efe2..91a3b79cb 100644
--- a/src/pint/models/ifunc.py
+++ b/src/pint/models/ifunc.py
@@ -138,5 +138,4 @@ def ifunc_phase(self, toas, delays):
else:
raise ValueError(f"Interpolation type {itype} not supported.")
- phase = ((times * u.s) * self._parent.F0.quantity).to(u.dimensionless_unscaled)
- return phase
+ return ((times * u.s) * self._parent.F0.quantity).to(u.dimensionless_unscaled)
diff --git a/src/pint/models/jump.py b/src/pint/models/jump.py
index 7c7095fef..17e231d25 100644
--- a/src/pint/models/jump.py
+++ b/src/pint/models/jump.py
@@ -117,7 +117,11 @@ def jump_phase(self, toas, delay):
F0.
"""
tbl = toas.table
- jphase = numpy.zeros(len(tbl)) * (self.JUMP1.units * self._parent.F0.units)
+ # base this on the first available jump (doesn't have to be JUMP1)
+ jphase = numpy.zeros(len(tbl)) * (
+ getattr(self, self.get_params_of_type("maskParameter")[0]).units
+ * self._parent.F0.units
+ )
for jump in self.jumps:
jump_par = getattr(self, jump)
mask = jump_par.select_toa_mask(toas)
diff --git a/src/pint/models/model_builder.py b/src/pint/models/model_builder.py
index a90383f85..a8417ff76 100644
--- a/src/pint/models/model_builder.py
+++ b/src/pint/models/model_builder.py
@@ -1,9 +1,14 @@
+"""Building a timing model from a par file."""
+
import copy
import warnings
from io import StringIO
from collections import Counter, defaultdict
from pathlib import Path
+from astropy import units as u
+from loguru import logger as log
+from pint.models.astrometry import Astrometry
from pint.models.parameter import maskParameter
from pint.models.timing_model import (
DEFAULT_ORDER,
@@ -20,8 +25,14 @@
ignore_prefix,
)
from pint.toa import get_TOAs
-from pint.utils import PrefixError, interesting_lines, lines_of, split_prefixed_name
-
+from pint.utils import (
+ PrefixError,
+ interesting_lines,
+ lines_of,
+ split_prefixed_name,
+ get_unit,
+)
+from pint.models.tcb_conversion import convert_tcb_tdb
__all__ = ["ModelBuilder", "get_model", "get_model_and_toas"]
@@ -71,9 +82,16 @@ def __init__(self):
# Validate the components
self.all_components = AllComponents()
self._validate_components()
- self.default_components = ["SolarSystemShapiro"]
+ self.default_components = []
- def __call__(self, parfile, allow_name_mixing=False):
+ def __call__(
+ self,
+ parfile,
+ allow_name_mixing=False,
+ allow_tcb=False,
+ toas_for_tzr=None,
+ **kwargs,
+ ):
"""Callable object for making a timing model from .par file.
Parameters
@@ -87,28 +105,98 @@ def __call__(self, parfile, allow_name_mixing=False):
T2EFAC and EFAC, both of them maps to PINT parameter EFAC, present
in the parfile at the same time.
+ allow_tcb : True, False, or "raw", optional
+ Whether to read TCB par files. Default is False, and will throw an
+ error upon encountering TCB par files. If True, the par file will be
+ converted to TDB upon read. If "raw", an unconverted malformed TCB
+ TimingModel object will be returned.
+
+ toas_for_tzr : TOAs or None, optional
+ If this is not None, a TZR TOA (AbsPhase) will be created using the
+ given TOAs object.
+
+ kwargs : dict
+ Any additional parameter/value pairs that will add to or override those in the parfile.
+
Returns
-------
pint.models.timing_model.TimingModel
The result timing model based on the input .parfile or file object.
"""
+
+ assert allow_tcb in [True, False, "raw"]
+ convert_tcb = allow_tcb == True
+ allow_tcb_ = allow_tcb in [True, "raw"]
+
pint_param_dict, original_name, unknown_param = self._pintify_parfile(
parfile, allow_name_mixing
)
+ remaining_args = {}
+ for k, v in kwargs.items():
+ if k not in pint_param_dict:
+ if isinstance(v, u.Quantity):
+ pint_param_dict[k] = [
+ str(v.to_value(get_unit(k))),
+ ]
+ else:
+ pint_param_dict[k] = [
+ str(v),
+ ]
+ original_name[k] = k
+ else:
+ remaining_args[k] = v
selected, conflict, param_not_in_pint = self.choose_model(pint_param_dict)
selected.update(set(self.default_components))
+
+ # Add SolarSystemShapiro only if an Astrometry component is present.
+ if any(
+ isinstance(self.all_components.components[sc], Astrometry)
+ for sc in selected
+ ):
+ selected.add("SolarSystemShapiro")
+
# Report conflict
if len(conflict) != 0:
self._report_conflict(conflict)
# Make timing model
cps = [self.all_components.components[c] for c in selected]
tm = TimingModel(components=cps)
- self._setup_model(tm, pint_param_dict, original_name, setup=True, validate=True)
+ self._setup_model(
+ tm,
+ pint_param_dict,
+ original_name,
+ setup=True,
+ validate=True,
+ allow_tcb=allow_tcb_,
+ )
# Report unknown line
for k, v in unknown_param.items():
p_line = " ".join([k] + v)
warnings.warn(f"Unrecognized parfile line '{p_line}'", UserWarning)
# log.warning(f"Unrecognized parfile line '{p_line}'")
+
+ if tm.UNITS.value == "TCB" and convert_tcb:
+ convert_tcb_tdb(tm)
+
+ for k, v in remaining_args.items():
+ if not hasattr(tm, k):
+ raise ValueError(f"Model does not have parameter '{k}'")
+ log.debug(f"Overriding '{k}' to '{v}'")
+ if isinstance(v, u.Quantity):
+ getattr(tm, k).quantity = v
+ else:
+ getattr(tm, k).value = v
+
+ # Explicitly add a TZR TOA from a given TOAs object.
+ if "AbsPhase" not in tm.components and toas_for_tzr is not None:
+ log.info("Creating a TZR TOA (AbsPhase) using the given TOAs object.")
+ tm.add_tzr_toa(toas_for_tzr)
+
+ if not hasattr(tm, "DelayComponent_list"):
+ setattr(tm, "DelayComponent_list", [])
+ if not hasattr(tm, "NoiseComponent_list"):
+ setattr(tm, "NoiseComponent_list", [])
+
return tm
def _validate_components(self):
@@ -166,9 +254,9 @@ def _get_component_param_overlap(self, component):
# Add aliases compare
overlap = in_param & cpm_param
# translate to PINT parameter
- overlap_pint_par = set(
- [self.all_components.alias_to_pint_param(ovlp)[0] for ovlp in overlap]
- )
+ overlap_pint_par = {
+ self.all_components.alias_to_pint_param(ovlp)[0] for ovlp in overlap
+ }
# The degree of overlapping for input component and compared component
overlap_deg_in = len(component.params) - len(overlap_pint_par)
overlap_deg_cpm = len(cp.params) - len(overlap_pint_par)
@@ -246,38 +334,40 @@ def _pintify_parfile(self, parfile, allow_name_mixing=False):
try:
pint_name, init0 = self.all_components.alias_to_pint_param(k)
except UnknownParameter:
- if k in ignore_params: # Parameter is known but in the ingore list
+ if k in ignore_params:
+ # Parameter is known but in the ignore list
continue
- else: # Check ignored prefix
- try:
- pfx, idxs, idx = split_prefixed_name(k)
- if pfx in ignore_prefix: # It is an ignored prefix.
- continue
- else:
- unknown_param[k] += v
- except PrefixError:
+ # Check ignored prefix
+ try:
+ pfx, idxs, idx = split_prefixed_name(k)
+ if pfx in ignore_prefix: # It is an ignored prefix.
+ continue
+ else:
unknown_param[k] += v
+ except PrefixError:
+ unknown_param[k] += v
continue
pint_param_dict[pint_name] += v
original_name_map[pint_name].append(k)
repeating[pint_name] += len(v)
# Check if this parameter is allowed to be repeated by PINT
- if len(pint_param_dict[pint_name]) > 1:
- if pint_name not in self.all_components.repeatable_param:
- raise TimingModelError(
- f"Parameter {pint_name} is not a repeatable parameter. "
- f"However, multiple line use it."
- )
+ if (
+ len(pint_param_dict[pint_name]) > 1
+ and pint_name not in self.all_components.repeatable_param
+ ):
+ raise TimingModelError(
+ f"Parameter {pint_name} is not a repeatable parameter. "
+ f"However, multiple line use it."
+ )
# Check if the name is mixed
for p_n, o_n in original_name_map.items():
- if len(o_n) > 1:
- if not allow_name_mixing:
- raise TimingModelError(
- f"Parameter {p_n} have mixed input names/alias "
- f"{o_n}. If you want to have mixing names, please use"
- f" 'allow_name_mixing=True', and the output .par file "
- f"will use '{original_name_map[pint_name][0]}'."
- )
+ if len(o_n) > 1 and not allow_name_mixing:
+ raise TimingModelError(
+ f"Parameter {p_n} have mixed input names/alias "
+ f"{o_n}. If you want to have mixing names, please use"
+ f" 'allow_name_mixing=True', and the output .par file "
+ f"will use '{original_name_map[pint_name][0]}'."
+ )
original_name_map[p_n] = o_n[0]
return pint_param_dict, original_name_map, unknown_param
@@ -347,12 +437,11 @@ def choose_model(self, param_inpar):
if p_name != first_init:
param_not_in_pint.append(pp)
- p_cp = self.all_components.param_component_map.get(first_init, None)
- if p_cp:
+ if p_cp := self.all_components.param_component_map.get(first_init, None):
param_components_inpar[p_name] = p_cp
# Back map the possible_components and the parameters in the parfile
# This will remove the duplicate components.
- conflict_components = defaultdict(set) # graph for confilict
+ conflict_components = defaultdict(set) # graph for conflict
for k, cps in param_components_inpar.items():
# If `timing_model` in param --> component mapping skip
# Timing model is the base.
@@ -386,10 +475,10 @@ def choose_model(self, param_inpar):
temp_cf_cp.remove(cp)
conflict_components[cp].update(set(temp_cf_cp))
continue
- # Check if the selected component in the confilict graph. If it is
- # remove the selected componens with its conflict components.
+ # Check if the selected component in the conflict graph. If it is
+ # remove the selected components with its conflict components.
for ps_cp in selected_components:
- cf_cps = conflict_components.get(ps_cp, None)
+ cf_cps = conflict_components.get(ps_cp)
if cf_cps is not None: # Had conflict, but resolved.
for cf_cp in cf_cps:
del conflict_components[cf_cp]
@@ -398,17 +487,17 @@ def choose_model(self, param_inpar):
selected_cates = {}
for cp in selected_components:
cate = self.all_components.component_category_map[cp]
- if cate not in selected_cates.keys():
- selected_cates[cate] = cp
- else:
- exisit_cp = selected_cates[cate]
+ if cate in selected_cates:
+ exist_cp = selected_cates[cate]
raise TimingModelError(
- f"Component '{cp}' and '{exisit_cp}' belong to the"
+ f"Component '{cp}' and '{exist_cp}' belong to the"
f" same category '{cate}'. Only one component from"
f" the same category can be used for a timing model."
f" Please check your input (e.g., .par file)."
)
+ else:
+ selected_cates[cate] = cp
return selected_components, conflict_components, param_not_in_pint
def _setup_model(
@@ -418,6 +507,7 @@ def _setup_model(
original_name=None,
setup=True,
validate=True,
+ allow_tcb=False,
):
"""Fill up a timing model with parameter values and then setup the model.
@@ -443,29 +533,28 @@ def _setup_model(
Whether to run the setup function in the timing model.
validate : bool, optional
Whether to run the validate function in the timing model.
+ allow_tcb : bool, optional
+ Whether to allow reading TCB par files
"""
- if original_name is not None:
- use_alias = True
- else:
- use_alias = False
+ use_alias = original_name is not None
for pp, v in pint_param_dict.items():
try:
par = getattr(timing_model, pp)
except AttributeError:
- # since the input is pintfied, it should be an uninitized indexed parameter
+ # since the input is pintfied, it should be an uninitialized indexed parameter
# double check if the missing parameter an indexed parameter.
pint_par, first_init = self.all_components.alias_to_pint_param(pp)
try:
prefix, _, index = split_prefixed_name(pint_par)
- except PrefixError:
+ except PrefixError as e:
par_hosts = self.all_components.param_component_map[pint_par]
- currnt_cp = timing_model.components.keys()
+ current_cp = timing_model.components.keys()
raise TimingModelError(
f"Parameter {pint_par} is recognized"
f" by PINT, but not used in the current"
f" timing model. It is used in {par_hosts},"
- f" but the current timing model uses {currnt_cp}."
- )
+ f" but the current timing model uses {current_cp}."
+ ) from e
# TODO need to create a better API for _locate_param_host
host_component = timing_model._locate_param_host(first_init)
timing_model.add_param_from_top(
@@ -477,10 +566,7 @@ def _setup_model(
# Fill up the values
param_line = len(v)
if param_line < 2:
- if use_alias:
- name = original_name[pp]
- else:
- name = pp
+ name = original_name[pp] if use_alias else pp
par.from_parfile_line(" ".join([name] + v))
else: # For the repeatable parameters
lines = copy.deepcopy(v) # Line queue.
@@ -516,20 +602,20 @@ def _setup_model(
# There is no current repeatable parameter matching the new line
# First try to fill up an empty space.
- if empty_repeat_param != []:
- emt_par = empty_repeat_param.pop(0)
- emt_par.from_parfile_line(" ".join([emt_par.name, li]))
- if use_alias: # Use the input alias as input
- emt_par.use_alias = original_name[pp]
- else:
+ if not empty_repeat_param:
# No empty space, add a new parameter to the timing model.
host_component = timing_model._locate_param_host(pp)
timing_model.add_param_from_top(temp_par, host_component[0][0])
+ else:
+ emt_par = empty_repeat_param.pop(0)
+ emt_par.from_parfile_line(" ".join([emt_par.name, li]))
+ if use_alias: # Use the input alias as input
+ emt_par.use_alias = original_name[pp]
if setup:
timing_model.setup()
if validate:
- timing_model.validate()
+ timing_model.validate(allow_tcb=allow_tcb)
return timing_model
def _report_conflict(self, conflict_graph):
@@ -538,12 +624,12 @@ def _report_conflict(self, conflict_graph):
# Put all the conflict components together from the graph
cf_cps = list(v)
cf_cps.append(k)
- raise ComponentConflict(
- "Can not decide the one component from:" " {}".format(cf_cps)
- )
+ raise ComponentConflict(f"Can not decide the one component from: {cf_cps}")
-def get_model(parfile, allow_name_mixing=False):
+def get_model(
+ parfile, allow_name_mixing=False, allow_tcb=False, toas_for_tzr=None, **kwargs
+):
"""A one step function to build model from a parfile.
Parameters
@@ -557,6 +643,19 @@ def get_model(parfile, allow_name_mixing=False):
T2EFAC and EFAC, both of them maps to PINT parameter EFAC, present
in the parfile at the same time.
+ allow_tcb : True, False, or "raw", optional
+ Whether to read TCB par files. Default is False, and will throw an
+ error upon encountering TCB par files. If True, the par file will be
+ converted to TDB upon read. If "raw", an unconverted malformed TCB
+ TimingModel object will be returned.
+
+ toas_for_tzr : TOAs or None, optional
+ If this is not None, a TZR TOA (AbsPhase) will be created using the
+ given TOAs object.
+
+ kwargs : dict
+ Any additional parameter/value pairs that will add to or override those in the parfile.
+
Returns
-------
Model instance get from parfile.
@@ -566,16 +665,28 @@ def get_model(parfile, allow_name_mixing=False):
contents = parfile.read()
except AttributeError:
contents = None
- if contents is None:
- # # parfile is a filename and can be handled by ModelBuilder
- # if _model_builder is None:
- # _model_builder = ModelBuilder()
- model = model_builder(parfile, allow_name_mixing)
- model.name = parfile
- return model
- else:
- tm = model_builder(StringIO(contents), allow_name_mixing)
- return tm
+ if contents is not None:
+ return model_builder(
+ StringIO(contents),
+ allow_name_mixing,
+ allow_tcb=allow_tcb,
+ toas_for_tzr=toas_for_tzr,
+ **kwargs,
+ )
+
+ # # parfile is a filename and can be handled by ModelBuilder
+ # if _model_builder is None:
+ # _model_builder = ModelBuilder()
+ model = model_builder(
+ parfile,
+ allow_name_mixing,
+ allow_tcb=allow_tcb,
+ toas_for_tzr=toas_for_tzr,
+ **kwargs,
+ )
+ model.name = parfile
+
+ return model
def get_model_and_toas(
@@ -592,6 +703,9 @@ def get_model_and_toas(
picklefilename=None,
allow_name_mixing=False,
limits="warn",
+ allow_tcb=False,
+ add_tzr_to_model=True,
+ **kwargs,
):
"""Load a timing model and a related TOAs, using model commands as needed
@@ -601,6 +715,9 @@ def get_model_and_toas(
The parfile name, or a file-like object to read the parfile contents from
timfile : str
The timfile name, or a file-like object to read the timfile contents from
+ ephem : str, optional
+ If not None (default), this ephemeris will be used to create the TOAs object.
+ Default is to use the EPHEM parameter from the timing model.
include_bipm : bool or None
Whether to apply the BIPM clock correction. Defaults to True.
bipm_version : string or None
@@ -630,12 +747,24 @@ def get_model_and_toas(
in the parfile at the same time.
limits : "warn" or "error"
What to do when encountering TOAs for which clock corrections are not available.
+ allow_tcb : True, False, or "raw", optional
+ Whether to read TCB par files. Default is False, and will throw an
+ error upon encountering TCB par files. If True, the par file will be
+ converted to TDB upon read. If "raw", an unconverted malformed TCB
+ TimingModel object will be returned.
+ add_tzr_to_model : bool, optional
+ Create a TZR TOA in the timing model using the created TOAs object. Default is
+ True.
+ kwargs : dict
+ Any additional parameter/value pairs that will add to or override those in the parfile.
Returns
-------
A tuple with (model instance, TOAs instance)
"""
- mm = get_model(parfile, allow_name_mixing)
+
+ mm = get_model(parfile, allow_name_mixing, allow_tcb=allow_tcb, **kwargs)
+
tt = get_TOAs(
timfile,
include_pn=include_pn,
@@ -650,4 +779,9 @@ def get_model_and_toas(
picklefilename=picklefilename,
limits=limits,
)
+
+ if "AbsPhase" not in mm.components and add_tzr_to_model:
+ log.info("Creating a TZR TOA (AbsPhase) using the given TOAs object.")
+ mm.add_tzr_toa(tt)
+
return mm, tt
diff --git a/src/pint/models/noise_model.py b/src/pint/models/noise_model.py
index a173f45a6..02a5d291a 100644
--- a/src/pint/models/noise_model.py
+++ b/src/pint/models/noise_model.py
@@ -89,7 +89,6 @@ def __init__(
def setup(self):
super().setup()
- # Get all the EFAC parameters and EQUAD
self.EFACs = {}
self.EQUADs = {}
self.TNEQs = {}
@@ -107,23 +106,17 @@ def setup(self):
continue
# convert all the TNEQ to EQUAD
- for tneq in self.TNEQs:
+ for tneq, value in self.TNEQs.items():
tneq_par = getattr(self, tneq)
if tneq_par.key is None:
continue
- if self.TNEQs[tneq] in list(self.EQUADs.values()):
+ if value in list(self.EQUADs.values()):
log.warning(
- "'%s %s %s' is provided by parameter EQUAD, using"
- " EQUAD instead. " % (tneq, tneq_par.key, tneq_par.key_value)
+ f"'{tneq} {tneq_par.key} {tneq_par.key_value}' is provided by parameter EQUAD, using EQUAD instead. "
)
else:
- EQUAD_name = "EQUAD" + str(tneq_par.index)
- if EQUAD_name in list(self.EQUADs.keys()):
- EQUAD_par = getattr(self, EQUAD_name)
- EQUAD_par.key = tneq_par.key
- EQUAD_par.key_value = tneq_par.key_value
- EQUAD_par.quantity = tneq_par.quantity.to(u.us)
- else:
+ EQUAD_name = f"EQUAD{str(tneq_par.index)}"
+ if EQUAD_name not in list(self.EQUADs.keys()):
self.add_param(
maskParameter(
name="EQUAD",
@@ -135,10 +128,10 @@ def setup(self):
" scaled (by EFAC) TOA uncertainty.",
)
)
- EQUAD_par = getattr(self, EQUAD_name)
- EQUAD_par.key = tneq_par.key
- EQUAD_par.key_value = tneq_par.key_value
- EQUAD_par.quantity = tneq_par.quantity.to(u.us)
+ EQUAD_par = getattr(self, EQUAD_name)
+ EQUAD_par.quantity = tneq_par.quantity.to(u.us)
+ EQUAD_par.key_value = tneq_par.key_value
+ EQUAD_par.key = tneq_par.key
for pp in self.params:
if pp.startswith("EQUAD"):
par = getattr(self, pp)
@@ -150,7 +143,7 @@ def validate(self):
for el in ["EFACs", "EQUADs"]:
l = list(getattr(self, el).values())
if [x for x in l if l.count(x) > 1] != []:
- raise ValueError("'%s' have duplicated keys and key values." % el)
+ raise ValueError(f"'{el}' have duplicated keys and key values.")
def scale_toa_sigma(self, toas):
sigma_scaled = toas.table["error"].quantity.copy()
@@ -250,7 +243,7 @@ def validate(self):
for el in ["DMEFACs", "DMEQUADs"]:
l = list(getattr(self, el).values())
if [x for x in l if l.count(x) > 1] != []:
- raise ValueError("'%s' have duplicated keys and key values." % el)
+ raise ValueError(f"'{el}' have duplicated keys and key values.")
def scale_dm_sigma(self, toas):
"""
@@ -341,13 +334,10 @@ def validate(self):
for el in ["ECORRs"]:
l = list(getattr(self, el).values())
if [x for x in l if l.count(x) > 1] != []:
- raise ValueError("'%s' have duplicated keys and key values." % el)
+ raise ValueError(f"'{el}' have duplicated keys and key values.")
def get_ecorrs(self):
- ecorrs = []
- for ecorr, ecorr_key in list(self.ECORRs.items()):
- ecorrs.append(getattr(self, ecorr))
- return ecorrs
+ return [getattr(self, ecorr) for ecorr, ecorr_key in list(self.ECORRs.items())]
def get_noise_basis(self, toas):
"""Return the quantization matrix for ECORR.
@@ -497,8 +487,7 @@ def get_noise_weights(self, toas):
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
amp, gam, nf = self.get_pl_vals()
Ffreqs = get_rednoise_freqs(t, nf)
- weights = powerlaw(Ffreqs, amp, gam) * Ffreqs[0]
- return weights
+ return powerlaw(Ffreqs, amp, gam) * Ffreqs[0]
def pl_dm_basis_weight_pair(self, toas):
"""Return a Fourier design matrix and power law DM noise weights.
@@ -613,8 +602,7 @@ def get_noise_basis(self, toas):
tbl = toas.table
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
nf = self.get_pl_vals()[2]
- Fmat = create_fourier_design_matrix(t, nf)
- return Fmat
+ return create_fourier_design_matrix(t, nf)
def get_noise_weights(self, toas):
"""Return power law red noise weights.
@@ -625,8 +613,7 @@ def get_noise_weights(self, toas):
t = (tbl["tdbld"].quantity * u.day).to(u.s).value
amp, gam, nf = self.get_pl_vals()
Ffreqs = get_rednoise_freqs(t, nf)
- weights = powerlaw(Ffreqs, amp, gam) * Ffreqs[0]
- return weights
+ return powerlaw(Ffreqs, amp, gam) * Ffreqs[0]
def pl_rn_basis_weight_pair(self, toas):
"""Return a Fourier design matrix and power law red noise weights.
@@ -662,9 +649,7 @@ def get_ecorr_epochs(toas_table, dt=1, nmin=2):
bucket_ref.append(toas_table[i])
bucket_ind.append([i])
- bucket_ind2 = [ind for ind in bucket_ind if len(ind) >= nmin]
-
- return bucket_ind2
+ return [ind for ind in bucket_ind if len(ind) >= nmin]
def get_ecorr_nweights(toas_table, dt=1, nmin=2):
@@ -689,15 +674,12 @@ def create_ecorr_quantization_matrix(toas_table, dt=1, nmin=2):
def get_rednoise_freqs(t, nmodes, Tspan=None):
"""Frequency components for creating the red noise basis matrix."""
- if Tspan is not None:
- T = Tspan
- else:
- T = t.max() - t.min()
+ T = Tspan if Tspan is not None else t.max() - t.min()
f = np.linspace(1 / T, nmodes / T, nmodes)
Ffreqs = np.zeros(2 * nmodes)
- Ffreqs[0::2] = f
+ Ffreqs[::2] = f
Ffreqs[1::2] = f
return Ffreqs
@@ -719,7 +701,7 @@ def create_fourier_design_matrix(t, nmodes, Tspan=None):
Ffreqs = get_rednoise_freqs(t, nmodes, Tspan=Tspan)
- F[:, ::2] = np.sin(2 * np.pi * t[:, None] * Ffreqs[0::2])
+ F[:, ::2] = np.sin(2 * np.pi * t[:, None] * Ffreqs[::2])
F[:, 1::2] = np.cos(2 * np.pi * t[:, None] * Ffreqs[1::2])
return F
diff --git a/src/pint/models/parameter.py b/src/pint/models/parameter.py
index 5543fbda3..8a14e0106 100644
--- a/src/pint/models/parameter.py
+++ b/src/pint/models/parameter.py
@@ -28,6 +28,7 @@
import astropy.units as u
import numpy as np
from astropy.coordinates.angles import Angle
+from uncertainties import ufloat
from loguru import logger as log
@@ -213,9 +214,8 @@ def quantity(self, val):
if val is None:
if hasattr(self, "quantity") and self.quantity is not None:
raise ValueError("Setting an existing value to None is not allowed.")
- else:
- self._quantity = val
- return
+ self._quantity = val
+ return
self._quantity = self._set_quantity(val)
@property
@@ -226,10 +226,7 @@ def value(self):
a :class:`~astropy.units.Quantity` can be provided, which will be converted
to ``self.units``.
"""
- if self._quantity is None:
- return None
- else:
- return self._get_value(self._quantity)
+ return None if self._quantity is None else self._get_value(self._quantity)
@value.setter
def value(self, val):
@@ -256,20 +253,19 @@ def units(self):
@units.setter
def units(self, unt):
# Check if this is the first time set units and check compatibility
- if hasattr(self, "quantity"):
- if self.units is not None:
- if unt != self.units:
- wmsg = "Parameter " + self.name + " default units has been "
- wmsg += " reset to " + str(unt) + " from " + str(self.units)
- log.warning(wmsg)
- try:
- if hasattr(self.quantity, "unit"):
- self.quantity.to(unt)
- except ValueError:
- log.warning(
- "The value unit is not compatible with"
- " parameter units right now."
- )
+ if hasattr(self, "quantity") and self.units is not None:
+ if unt != self.units:
+ wmsg = f"Parameter {self.name} default units has been "
+ wmsg += f" reset to {str(unt)} from {str(self.units)}"
+ log.warning(wmsg)
+ try:
+ if hasattr(self.quantity, "unit"):
+ self.quantity.to(unt)
+ except ValueError:
+ log.warning(
+ "The value unit is not compatible with"
+ " parameter units right now."
+ )
if unt is None:
self._units = None
@@ -306,12 +302,12 @@ def uncertainty(self, val):
raise ValueError(
"Setting an existing uncertainty to None is not allowed."
)
- else:
- self._uncertainty = self._uncertainty_value = None
- return
+ self._uncertainty = self._uncertainty_value = None
+ return
+
val = self._set_uncertainty(val)
- if not val >= 0:
+ if val < 0:
raise ValueError(f"Uncertainties cannot be negative but {val} was supplied")
# self.uncertainty_value = np.abs(self.uncertainty_value)
@@ -416,16 +412,16 @@ def _print_uncertainty(self, uncertainty):
return str(uncertainty.to(self.units).value)
def __repr__(self):
- out = "{0:16s}{1:20s}".format(self.__class__.__name__ + "(", self.name)
+ out = "{0:16s}{1:20s}".format(f"{self.__class__.__name__}(", self.name)
if self.quantity is None:
out += "UNSET"
return out
out += "{:17s}".format(self.str_quantity(self.quantity))
if self.units is not None:
- out += " (" + str(self.units) + ")"
+ out += f" ({str(self.units)})"
if self.uncertainty is not None and isinstance(self.value, numbers.Number):
- out += " +/- " + str(self.uncertainty.to(self.units))
- out += " frozen={}".format(self.frozen)
+ out += f" +/- {str(self.uncertainty.to(self.units))}"
+ out += f" frozen={self.frozen}"
out += ")"
return out
@@ -433,7 +429,7 @@ def help_line(self):
"""Return a help line containing parameter name, description and units."""
out = "%-12s %s" % (self.name, self.description)
if self.units is not None:
- out += " (" + str(self.units) + ")"
+ out += f" ({str(self.units)})"
return out
def as_parfile_line(self, format="pint"):
@@ -443,7 +439,7 @@ def as_parfile_line(self, format="pint"):
----------
format : str, optional
Parfile output format. PINT outputs in 'tempo', 'tempo2' and 'pint'
- formats. The defaul format is `pint`.
+ formats. The default format is `pint`.
Returns
-------
@@ -458,28 +454,25 @@ def as_parfile_line(self, format="pint"):
assert (
format.lower() in _parfile_formats
), "parfile format must be one of %s" % ", ".join(
- ['"%s"' % x for x in _parfile_formats]
+ [f'"{x}"' for x in _parfile_formats]
)
# Don't print unset parameters
if self.quantity is None:
return ""
- if self.use_alias is None:
- name = self.name
- else:
- name = self.use_alias
+ name = self.name if self.use_alias is None else self.use_alias
# special cases for parameter names that change depending on format
- if self.name == "CHI2" and not (format.lower() == "pint"):
+ if self.name == "CHI2" and format.lower() != "pint":
# no CHI2 for TEMPO/TEMPO2
return ""
- elif self.name == "SWM" and not (format.lower() == "pint"):
+ elif self.name == "SWM" and format.lower() != "pint":
# no SWM for TEMPO/TEMPO2
return ""
- elif self.name == "A1DOT" and not (format.lower() == "pint"):
+ elif self.name == "A1DOT" and format.lower() != "pint":
# change to XDOT for TEMPO/TEMPO2
name = "XDOT"
- elif self.name == "STIGMA" and not (format.lower() == "pint"):
+ elif self.name == "STIGMA" and format.lower() != "pint":
# change to VARSIGMA for TEMPO/TEMPO2
name = "VARSIGMA"
@@ -493,22 +486,22 @@ def as_parfile_line(self, format="pint"):
)
# change ECL value to IERS2003 for TEMPO2
line = "%-15s %25s" % (name, "IERS2003")
- elif self.name == "NHARMS" and not (format.lower() == "pint"):
+ elif self.name == "NHARMS" and format.lower() != "pint":
# convert NHARMS value to int
line = "%-15s %25d" % (name, self.value)
elif self.name == "KIN" and format.lower() == "tempo":
# convert from DT92 convention to IAU
line = "%-15s %25s" % (name, self.str_quantity(180 * u.deg - self.quantity))
log.warning(
- f"Changing KIN from DT92 convention to IAU: this will not be readable by PINT"
+ "Changing KIN from DT92 convention to IAU: this will not be readable by PINT"
)
elif self.name == "KOM" and format.lower() == "tempo":
# convert from DT92 convention to IAU
line = "%-15s %25s" % (name, self.str_quantity(90 * u.deg - self.quantity))
log.warning(
- f"Changing KOM from DT92 convention to IAU: this will not be readable by PINT"
+ "Changing KOM from DT92 convention to IAU: this will not be readable by PINT"
)
- elif self.name == "DMDATA" and not format.lower() == "pint":
+ elif self.name == "DMDATA" and format.lower() != "pint":
line = "%-15s %d" % (self.name, int(self.value))
if self.uncertainty is not None:
@@ -521,7 +514,7 @@ def as_parfile_line(self, format="pint"):
if self.name == "T2CMETHOD" and format.lower() == "tempo2":
# comment out T2CMETHOD for TEMPO2
- line = "#" + line
+ line = f"#{line}"
return line + "\n"
def from_parfile_line(self, line):
@@ -568,10 +561,10 @@ def from_parfile_line(self, line):
try:
str2longdouble(k[2])
ucty = k[2]
- except ValueError:
+ except ValueError as e:
errmsg = f"Unidentified string '{k[2]}' in"
- errmsg += f" parfile line " + " ".join(k)
- raise ValueError(errmsg)
+ errmsg += " parfile line " + " ".join(k)
+ raise ValueError(errmsg) from e
if len(k) >= 4:
ucty = k[3]
@@ -754,15 +747,13 @@ def _set_quantity(self, val):
except AttributeError:
# This will happen if the input value did not have units
num_value = setfunc_no_unit(val)
- if self.unit_scale:
- # For some parameters, if the value is above a threshold, it is assumed to be in units of scale_factor
- # e.g. "PBDOT 7.2" is interpreted as "PBDOT 7.2E-12", since the scale_factor is 1E-12 and the scale_threshold is 1E-7
- if np.abs(num_value) > np.abs(self.scale_threshold):
- log.info(
- "Parameter %s's value will be scaled by %s"
- % (self.name, str(self.scale_factor))
- )
- num_value *= self.scale_factor
+ # For some parameters, if the value is above a threshold, it is assumed to be in units of scale_factor
+ # e.g. "PBDOT 7.2" is interpreted as "PBDOT 7.2E-12", since the scale_factor is 1E-12 and the scale_threshold is 1E-7
+ if self.unit_scale and np.abs(num_value) > np.abs(self.scale_threshold):
+ log.info(
+ f"Parameter {self.name}'s value will be scaled by {str(self.scale_factor)}"
+ )
+ num_value *= self.scale_factor
result = num_value * self.units
return result
@@ -773,22 +764,64 @@ def _set_uncertainty(self, val):
def str_quantity(self, quan):
"""Quantity as a string (for floating-point values)."""
v = quan.to(self.units).value
- if self._long_double:
- if not isinstance(v, np.longdouble):
- raise ValueError(
- "Parameter is supposed to contain long double values but contains a float"
- )
+ if self._long_double and not isinstance(v, np.longdouble):
+ raise ValueError(
+ "Parameter is supposed to contain long double values but contains a float"
+ )
return str(v)
def _get_value(self, quan):
"""Convert to appropriate units and extract value."""
if quan is None:
return None
- elif isinstance(quan, float) or isinstance(quan, np.longdouble):
+ elif isinstance(quan, (float, np.longdouble)):
return quan
else:
return quan.to(self.units).value
+ def as_ufloat(self, units=None):
+ """Return the parameter as a :class:`uncertainties.ufloat`
+
+ Will cast to the specified units, or the default
+ If the uncertainty is not set will be returned as 0
+
+ Parameters
+ ----------
+ units : astropy.units.core.Unit, optional
+ Units to cast the value
+
+ Returns
+ -------
+ uncertainties.ufloat
+
+ Notes
+ -----
+ Currently :class:`~uncertainties.ufloat` does not support double precision values,
+ so some precision may be lost.
+ """
+ if units is None:
+ units = self.units
+ value = self.quantity.to_value(units) if self.quantity is not None else 0
+ error = self.uncertainty.to_value(units) if self.uncertainty is not None else 0
+ return ufloat(value, error)
+
+ def from_ufloat(self, value, units=None):
+ """Set the parameter from the value of a :class:`uncertainties.ufloat`
+
+ Will cast to the specified units, or the default
+ If the uncertainty is 0 it will be set to ``None``
+
+ Parameters
+ ----------
+ value : uncertainties.ufloat
+ units : astropy.units.core.Unit, optional
+ Units to cast the value
+ """
+ if units is None:
+ units = self.units
+ self.quantity = value.n * units
+ self.uncertainty = value.s * units if value.s > 0 else None
+
class strParameter(Parameter):
"""String-valued parameter.
@@ -948,15 +981,14 @@ def _set_quantity(self, val):
if isinstance(val, str):
try:
ival = int(val)
- except ValueError:
+ except ValueError as e:
fval = float(val)
ival = int(fval)
if ival != fval and abs(fval) < 2**52:
raise ValueError(
f"Value {val} does not appear to be an integer "
f"but parameter {self.name} stores only integers."
- )
- return ival
+ ) from e
else:
ival = int(val)
fval = float(val)
@@ -965,7 +997,8 @@ def _set_quantity(self, val):
f"Value {val} does not appear to be an integer "
f"but parameter {self.name} stores only integers."
)
- return ival
+
+ return ival
class MJDParameter(Parameter):
@@ -1095,7 +1128,7 @@ def _set_quantity(self, val):
result = val
else:
raise ValueError(
- "MJD parameter can not accept " + type(val).__name__ + "format."
+ f"MJD parameter can not accept {type(val).__name__}format."
)
return result
@@ -1113,6 +1146,23 @@ def _set_uncertainty(self, val):
def _print_uncertainty(self, uncertainty):
return str(self.uncertainty_value)
+ def as_ufloats(self):
+ """Return the parameter as a pair of :class:`uncertainties.ufloat`
+ values representing the integer and fractional Julian dates.
+ The uncertainty is carried by the latter.
+
+ If the uncertainty is not set will be returned as 0
+
+ Returns
+ -------
+ uncertainties.ufloat
+ uncertainties.ufloat
+ """
+ value1 = self.quantity.jd1 if self.quantity is not None else 0
+ value2 = self.quantity.jd2 if self.quantity is not None else 0
+ error = self.uncertainty.to_value(u.d) if self.uncertainty is not None else 0
+ return ufloat(value1, 0), ufloat(value2, error)
+
class AngleParameter(Parameter):
"""Parameter in angle units.
@@ -1170,7 +1220,7 @@ def __init__(
}
# Check unit format
if units.lower() not in self.unit_identifier.keys():
- raise ValueError("Unidentified unit " + units)
+ raise ValueError(f"Unidentified unit {units}")
self.unitsuffix = self.unit_identifier[units.lower()][1]
self.value_type = Angle
@@ -1209,7 +1259,7 @@ def _set_quantity(self, val):
result = Angle(val.to(self.units))
else:
raise ValueError(
- "Angle parameter can not accept " + type(val).__name__ + "format."
+ f"Angle parameter can not accept {type(val).__name__}format."
)
return result
@@ -1225,7 +1275,7 @@ def _set_uncertainty(self, val):
result = Angle(val.to(self.unit_identifier[self._str_unit.lower()][2]))
else:
raise ValueError(
- "Angle parameter can not accept " + type(val).__name__ + "format."
+ f"Angle parameter can not accept {type(val).__name__}format."
)
return result
@@ -1238,14 +1288,14 @@ def str_quantity(self, quan):
def _print_uncertainty(self, unc):
"""This is a function for printing out the uncertainty"""
- if ":" in self._str_unit:
- angle_arcsec = unc.to(u.arcsec)
- if self.units == u.hourangle:
- # Triditionaly hourangle uncertainty is in hourangle seconds
- angle_arcsec /= 15.0
- return angle_arcsec.to_string(decimal=True, precision=20)
- else:
+ if ":" not in self._str_unit:
return unc.to_string(decimal=True, precision=20)
+ angle_arcsec = unc.to(u.arcsec)
+
+ if self.units == u.hourangle:
+ # Traditionally, hourangle uncertainty is in hourangle seconds
+ angle_arcsec /= 15.0
+ return angle_arcsec.to_string(decimal=True, precision=20)
class prefixParameter:
@@ -1335,8 +1385,8 @@ def __init__(
self.parameter_type = parameter_type
try:
self.param_class = self.type_mapping[self.parameter_type.lower()]
- except KeyError:
- raise ValueError("Unknown parameter type '" + parameter_type + "' ")
+ except KeyError as e:
+ raise ValueError(f"Unknown parameter type '{parameter_type}' ") from e
# Set up other attributes in the wrapper class
self.unit_template = unit_template
@@ -1356,9 +1406,7 @@ def __init__(
real_description = self.description_template(self.index)
else:
real_description = input_description
- aliases = []
- for pa in self.prefix_aliases:
- aliases.append(pa + self.idxfmt)
+ aliases = [pa + self.idxfmt for pa in self.prefix_aliases]
self.long_double = long_double
# initiate parameter class
self.param_comp = self.param_class(
@@ -1517,27 +1565,45 @@ def new_param(self, index, inheritfrozen=False):
A prefixed parameter with the same type of instance.
"""
- new_name = self.prefix + format(index, "0" + str(len(self.idxfmt)))
- kws = dict()
- for key in [
- "units",
- "unit_template",
- "description",
- "description_template",
- "frozen",
- "continuous",
- "prefix_aliases",
- "long_double",
- "time_scale",
- "parameter_type",
- ]:
- if hasattr(self, key):
- if (key == "frozen") and not (inheritfrozen):
- continue
- kws[key] = getattr(self, key)
+ new_name = self.prefix + format(index, f"0{len(self.idxfmt)}")
+ kws = {
+ key: getattr(self, key)
+ for key in [
+ "units",
+ "unit_template",
+ "description",
+ "description_template",
+ "frozen",
+ "continuous",
+ "prefix_aliases",
+ "long_double",
+ "time_scale",
+ "parameter_type",
+ ]
+ if hasattr(self, key) and (key != "frozen" or inheritfrozen)
+ }
+ return prefixParameter(name=new_name, **kws)
+
+ def as_ufloat(self, units=None):
+ """Return the parameter as a :class:`uncertainties.ufloat`
+
+ Will cast to the specified units, or the default
+ If the uncertainty is not set will be returned as 0
- newpfx = prefixParameter(name=new_name, **kws)
- return newpfx
+ Parameters
+ ----------
+ units : astropy.units.core.Unit, optional
+ Units to cast the value
+
+ Returns
+ -------
+ uncertainties.ufloat
+ """
+ if units is None:
+ units = self.units
+ value = self.quantity.to_value(units) if self.quantity is not None else 0
+ error = self.uncertainty.to_value(units) if self.uncertainty is not None else 0
+ return ufloat(value, error)
class maskParameter(floatParameter):
@@ -1628,19 +1694,18 @@ def __init__(
# Check key and key value
key_value_parser = str
if key is not None:
- if key.lower() in self.key_identifier.keys():
+ if key.lower() in self.key_identifier:
key_info = self.key_identifier[key.lower()]
if len(key_value) != key_info[1]:
errmsg = f"key {key} takes {key_info[1]} element(s)."
raise ValueError(errmsg)
key_value_parser = key_info[0]
- else:
- if not key.startswith("-"):
- raise ValueError(
- "A key to a TOA flag requires a leading '-'."
- " Legal keywords that don't require a leading '-' "
- "are MJD, FREQ, NAME, TEL."
- )
+ elif not key.startswith("-"):
+ raise ValueError(
+ "A key to a TOA flag requires a leading '-'."
+ " Legal keywords that don't require a leading '-' "
+ "are MJD, FREQ, NAME, TEL."
+ )
self.key = key
self.key_value = [
key_value_parser(k) for k in key_value
@@ -1650,10 +1715,7 @@ def __init__(
name_param = name + str(index)
self.origin_name = name
self.prefix = self.origin_name
- # Make aliases with index.
- idx_aliases = []
- for al in aliases:
- idx_aliases.append(al + str(self.index))
+ idx_aliases = [al + str(self.index) for al in aliases]
self.prefix_aliases = aliases
super().__init__(
name=name_param,
@@ -1675,21 +1737,22 @@ def __init__(
self._parfile_name = self.origin_name
def __repr__(self):
- out = self.__class__.__name__ + "(" + self.name
+ out = f"{self.__class__.__name__}({self.name}"
if self.key is not None:
- out += " " + self.key
+ out += f" {self.key}"
if self.key_value is not None:
for kv in self.key_value:
- out += " " + str(kv)
+ out += f" {str(kv)}"
if self.quantity is not None:
- out += " " + self.str_quantity(self.quantity)
+ out += f" {self.str_quantity(self.quantity)}"
else:
- out += " " + "UNSET"
+ out += " UNSET"
return out
+
if self.uncertainty is not None and isinstance(self.value, numbers.Number):
- out += " +/- " + str(self.uncertainty.to(self.units))
+ out += f" +/- {str(self.uncertainty.to(self.units))}"
if self.units is not None:
- out += " (" + str(self.units) + ")"
+ out += f" ({str(self.units)})"
out += ")"
return out
@@ -1738,10 +1801,10 @@ def from_parfile_line(self, line):
try:
self.key = k[1]
- except IndexError:
+ except IndexError as e:
raise ValueError(
"{}: No key found on timfile line {!r}".format(self.name, line)
- )
+ ) from e
key_value_info = self.key_identifier.get(self.key.lower(), (str, 1))
len_key_v = key_value_info[1]
@@ -1781,10 +1844,10 @@ def from_parfile_line(self, line):
try:
str2longdouble(k[3 + len_key_v])
ucty = k[3 + len_key_v]
- except ValueError:
- errmsg = "Unidentified string " + k[3 + len_key_v] + " in"
- errmsg += " parfile line " + k
- raise ValueError(errmsg)
+ except ValueError as exc:
+ errmsg = f"Unidentified string {k[3 + len_key_v]} in"
+ errmsg += f" parfile line {k}"
+ raise ValueError(errmsg) from exc
if len(k) >= 5 + len_key_v:
ucty = k[4 + len_key_v]
@@ -1795,21 +1858,19 @@ def as_parfile_line(self, format="pint"):
assert (
format.lower() in _parfile_formats
), "parfile format must be one of %s" % ", ".join(
- ['"%s"' % x for x in _parfile_formats]
+ [f'"{x}"' for x in _parfile_formats]
)
if self.quantity is None:
return ""
- if self.use_alias is None:
- name = self.origin_name
- else:
- name = self.use_alias
+
+ name = self.origin_name if self.use_alias is None else self.use_alias
# special cases for parameter names that change depending on format
- if name == "EFAC" and not (format.lower() == "pint"):
+ if name == "EFAC" and format.lower() != "pint":
# change to T2EFAC for TEMPO/TEMPO2
name = "T2EFAC"
- elif name == "EQUAD" and not (format.lower() == "pint"):
+ elif name == "EQUAD" and format.lower() != "pint":
# change to T2EQUAD for TEMPO/TEMPO2
name = "T2EQUAD"
@@ -1830,16 +1891,8 @@ def as_parfile_line(self, format="pint"):
def new_param(self, index, copy_all=False):
"""Create a new but same style mask parameter"""
- if not copy_all:
- new_mask_param = maskParameter(
- name=self.origin_name,
- index=index,
- long_double=self.long_double,
- units=self.units,
- aliases=self.prefix_aliases,
- )
- else:
- new_mask_param = maskParameter(
+ return (
+ maskParameter(
name=self.origin_name,
index=index,
key=self.key,
@@ -1853,7 +1906,15 @@ def new_param(self, index, copy_all=False):
continuous=self.continuous,
aliases=self.prefix_aliases,
)
- return new_mask_param
+ if copy_all
+ else maskParameter(
+ name=self.origin_name,
+ index=index,
+ long_double=self.long_double,
+ units=self.units,
+ aliases=self.prefix_aliases,
+ )
+ )
def select_toa_mask(self, toas):
"""Select the toas that match the mask.
@@ -1867,7 +1928,6 @@ def select_toa_mask(self, toas):
array
An array of TOA indices selected by the mask.
"""
- column_match = {"mjd": "mjd_float", "freq": "freq", "tel": "obs"}
if len(self.key_value) == 1:
if not hasattr(self, "toa_selector"):
self.toa_selector = TOASelect(is_range=False, use_hash=True)
@@ -1880,18 +1940,15 @@ def select_toa_mask(self, toas):
return np.array([], dtype=int)
else:
raise ValueError(
- "Parameter %s has more key values than "
- "expected.(Expect 1 or 2 key values)" % self.name
+ f"Parameter {self.name} has more key values than expected.(Expect 1 or 2 key values)"
)
# get the table columns
# TODO Right now it is only supports mjd, freq, tel, and flagkeys,
# We need to consider some more complicated situation
- if self.key.startswith("-"):
- key = self.key[1::]
- else:
- key = self.key
+ key = self.key[1::] if self.key.startswith("-") else self.key
tbl = toas.table
+ column_match = {"mjd": "mjd_float", "freq": "freq", "tel": "obs"}
if (
self.key.lower() not in column_match
): # This only works for the one with flags.
@@ -1924,13 +1981,11 @@ def compare_key_value(self, other_param):
ValueError:
If the parameter to compare does not have 'key' or 'key_value'.
"""
- if not (hasattr(other_param, "key") or hasattr(other_param, "key_value")):
+ if not hasattr(other_param, "key") and not hasattr(other_param, "key_value"):
raise ValueError("Parameter to compare does not have `key` or `key_value`.")
if self.key != other_param.key:
return False
- if self.key_value != other_param.key_value:
- return False
- return True
+ return self.key_value == other_param.key_value
class pairParameter(floatParameter):
@@ -1998,9 +2053,8 @@ def __init__(
def name_matches(self, name):
if super().name_matches(name):
return True
- else:
- name_idx = name + str(self.index)
- return super().name_matches(name_idx)
+ name_idx = name + str(self.index)
+ return super().name_matches(name_idx)
def from_parfile_line(self, line):
"""Read mask parameter line (e.g. JUMP).
@@ -2033,10 +2087,7 @@ def as_parfile_line(self, format="pint"):
quantity = self.quantity
if self.quantity is None:
return ""
- if self.use_alias is None:
- name = self.name
- else:
- name = self.use_alias
+ name = self.name if self.use_alias is None else self.use_alias
line = "%-15s " % name
line += "%25s" % self.str_quantity(quantity[0])
line += " %25s" % self.str_quantity(quantity[1])
@@ -2045,14 +2096,13 @@ def as_parfile_line(self, format="pint"):
def new_param(self, index):
"""Create a new but same style mask parameter."""
- new_pair_param = pairParameter(
+ return pairParameter(
name=self.origin_name,
index=index,
long_double=self.long_double,
units=self.units,
aliases=self.prefix_aliases,
)
- return new_pair_param
def _set_quantity(self, vals):
vals = [floatParameter._set_quantity(self, val) for val in vals]
@@ -2067,10 +2117,7 @@ def value(self):
This value will associate with parameter default value, which is .units attribute.
"""
- if self._quantity is None:
- return None
- else:
- return self._get_value(self._quantity)
+ return None if self._quantity is None else self._get_value(self._quantity)
@value.setter
def value(self, val):
@@ -2098,26 +2145,22 @@ def str_quantity(self, quan):
except AttributeError:
# Not a quantity, let's hope it's a list of length two?
if len(quan) != 2:
- raise ValueError("Don't know how to print this as a pair: %s" % (quan,))
+ raise ValueError(f"Don't know how to print this as a pair: {quan}")
v0 = quan[0].to(self.units).value
v1 = quan[1].to(self.units).value
if self._long_double:
if not isinstance(v0, np.longdouble):
raise TypeError(
- "Parameter {} is supposed to contain long doubles but contains a float".format(
- self
- )
+ f"Parameter {self} is supposed to contain long doubles but contains a float"
)
if not isinstance(v1, np.longdouble):
raise TypeError(
- "Parameter {} is supposed to contain long doubles but contains a float".format(
- self
- )
+ f"Parameter {self} is supposed to contain long doubles but contains a float"
)
quan0 = str(v0)
quan1 = str(v1)
- return quan0 + " " + quan1
+ return f"{quan0} {quan1}"
class funcParameter(floatParameter):
@@ -2271,7 +2314,7 @@ def _get_parentage(self, max_level=2):
self._parentlevel = []
for i, p in enumerate(self._params):
parent = self._parent
- for level in range(max_level):
+ for _ in range(max_level):
if hasattr(parent, p):
self._parentlevel.append(parent)
break
diff --git a/src/pint/models/phase_offset.py b/src/pint/models/phase_offset.py
new file mode 100644
index 000000000..3793b5bca
--- /dev/null
+++ b/src/pint/models/phase_offset.py
@@ -0,0 +1,52 @@
+"""Explicit phase offset"""
+
+from pint.models.timing_model import PhaseComponent
+from pint.models.parameter import floatParameter
+from astropy import units as u
+import numpy as np
+
+
+class PhaseOffset(PhaseComponent):
+ """Explicit pulse phase offset between physical TOAs and the TZR TOA.
+ See `examples/phase_offset_example.py` for example usage.
+
+ Parameters supported:
+
+ .. paramtable::
+ :class: pint.models.phase_offset.PhaseOffset
+ """
+
+ register = True
+ category = "phase_offset"
+
+ def __init__(self):
+ super().__init__()
+ self.add_param(
+ floatParameter(
+ name="PHOFF",
+ value=0.0,
+ units="",
+ description="Overall phase offset between physical TOAs and the TZR TOA.",
+ )
+ )
+ self.phase_funcs_component += [self.offset_phase]
+ self.register_deriv_funcs(self.d_offset_phase_d_PHOFF, "PHOFF")
+
+ def offset_phase(self, toas, delay):
+ """An overall phase offset between physical TOAs and the TZR TOA."""
+
+ return (
+ (np.zeros(len(toas)) * self.PHOFF.quantity).to(u.dimensionless_unscaled)
+ if toas.tzr
+ else (-np.ones(len(toas)) * self.PHOFF.quantity).to(
+ u.dimensionless_unscaled
+ )
+ )
+
+ def d_offset_phase_d_PHOFF(self, toas, param, delay):
+ """Derivative of the pulse phase w.r.t. PHOFF"""
+ return (
+ np.zeros(len(toas)) * u.Unit("")
+ if toas.tzr
+ else -np.ones(len(toas)) * u.Unit("")
+ )
diff --git a/src/pint/models/piecewise.py b/src/pint/models/piecewise.py
index e69637045..9e78cd2ff 100644
--- a/src/pint/models/piecewise.py
+++ b/src/pint/models/piecewise.py
@@ -231,10 +231,7 @@ def d_phase_d_F(self, toas, param, delay):
par = getattr(self, param)
unit = par.units
pn, idxf, idxv = split_prefixed_name(param)
- if param.startswith("PWF"):
- order = split_prefixed_name(param[:4])[2] + 1
- else:
- order = 0
+ order = split_prefixed_name(param[:4])[2] + 1 if param.startswith("PWF") else 0
# order = idxv + 1
fterms = self.get_spin_terms(idxv)
# make the chosen fterms 1 others 0
diff --git a/src/pint/models/priors.py b/src/pint/models/priors.py
index 3ad17f0a9..e39c9237c 100644
--- a/src/pint/models/priors.py
+++ b/src/pint/models/priors.py
@@ -52,7 +52,6 @@ class Prior:
def __init__(self, rv):
self._rv = rv
- pass
def pdf(self, value):
# The astype() calls prevent unsafe cast messages
@@ -113,14 +112,13 @@ def UniformBoundedRV(lower_bound, upper_bound):
Returns a frozen rv_continuous instance with a uniform probability
inside the range lower_bound to upper_bound and 0.0 outside
"""
- uu = uniform(lower_bound, (upper_bound - lower_bound))
- return uu
+ return uniform(lower_bound, (upper_bound - lower_bound))
class GaussianRV_gen(rv_continuous):
r"""A Gaussian prior between two bounds.
If you just want a gaussian, use scipy.stats.norm
- This version is for generating bounded gaussians
+ This version is for generating bounded Gaussians
Parameters
----------
@@ -132,8 +130,7 @@ class GaussianRV_gen(rv_continuous):
"""
def _pdf(self, x):
- ret = np.exp(-(x**2) / 2) / np.sqrt(2 * np.pi)
- return ret
+ return np.exp(-(x**2) / 2) / np.sqrt(2 * np.pi)
def GaussianBoundedRV(loc=0.0, scale=1.0, lower_bound=-np.inf, upper_bound=np.inf):
@@ -153,5 +150,4 @@ def GaussianBoundedRV(loc=0.0, scale=1.0, lower_bound=-np.inf, upper_bound=np.in
ymin = (lower_bound - loc) / scale
ymax = (upper_bound - loc) / scale
n = GaussianRV_gen(name="bounded_gaussian", a=ymin, b=ymax)
- nn = n(loc=loc, scale=scale)
- return nn
+ return n(loc=loc, scale=scale)
diff --git a/src/pint/models/pulsar_binary.py b/src/pint/models/pulsar_binary.py
index c143f2eb8..5e32079ad 100644
--- a/src/pint/models/pulsar_binary.py
+++ b/src/pint/models/pulsar_binary.py
@@ -8,7 +8,6 @@
import astropy.units as u
import contextlib
import numpy as np
-from astropy.time import Time
from astropy.coordinates import SkyCoord
from loguru import logger as log
@@ -26,7 +25,7 @@
TimingModelError,
UnknownParameter,
)
-from pint.utils import taylor_horner_deriv
+from pint.utils import taylor_horner_deriv, parse_time
from pint.pulsar_ecliptic import PulsarEcliptic
@@ -404,7 +403,12 @@ def update_binary_object(self, toas, acc_delay=None):
except UnknownParameter:
if par in self.internal_params:
pint_bin_name = par
+ else:
+ raise UnknownParameter(
+ f"Unable to find {par} in the parent model"
+ )
binObjpar = getattr(self._parent, pint_bin_name)
+
# make sure we aren't passing along derived parameters to the binary instance
if isinstance(binObjpar, funcParameter):
continue
@@ -482,10 +486,7 @@ def change_binary_epoch(self, new_epoch):
new_epoch: float MJD (in TDB) or `astropy.Time` object
The new epoch value.
"""
- if isinstance(new_epoch, Time):
- new_epoch = Time(new_epoch, scale="tdb", precision=9)
- else:
- new_epoch = Time(new_epoch, scale="tdb", format="mjd", precision=9)
+ new_epoch = parse_time(new_epoch, scale="tdb", precision=9)
# Get PB and PBDOT from model
if self.PB.quantity is not None and not isinstance(self.PB, funcParameter):
@@ -537,3 +538,64 @@ def change_binary_epoch(self, new_epoch):
self.OM.quantity = self.OM.quantity + dOM
dA1 = self.A1DOT.quantity * dt_integer_orbits
self.A1.quantity = self.A1.quantity + dA1
+
+ def pb(self, t=None):
+ """Return binary period and uncertainty (optionally evaluated at different times) regardless of binary model
+
+ Parameters
+ ----------
+ t : astropy.time.Time, astropy.units.Quantity, numpy.ndarray, float, int, str, optional
+ Time(s) to evaluate period
+
+ Returns
+ -------
+ astropy.units.Quantity :
+ Binary period
+ astropy.units.Quantity :
+ Binary period uncertainty
+
+ """
+ if self.binary_model_name.startswith("ELL1"):
+ t0 = self.TASC.quantity
+ else:
+ t0 = self.T0.quantity
+ t = t0 if t is None else parse_time(t)
+ if self.PB.quantity is not None:
+ if self.PBDOT.quantity is None and (
+ not hasattr(self, "XPBDOT")
+ or getattr(self, "XPBDOT").quantity is not None
+ ):
+ return self.PB.quantity, self.PB.uncertainty
+ pb = self.PB.as_ufloat(u.d)
+ if self.PBDOT.quantity is not None:
+ pbdot = self.PBDOT.as_ufloat(u.s / u.s)
+ if hasattr(self, "XPBDOT") and self.XPBDOT.quantity is not None:
+ pbdot += self.XPBDOT.as_ufloat(u.s / u.s)
+ pnew = pb + pbdot * (t - t0).jd
+ if not isinstance(pnew, np.ndarray):
+ return pnew.n * u.d, pnew.s * u.d if pnew.s > 0 else None
+ import uncertainties.unumpy
+
+ return (
+ uncertainties.unumpy.nominal_values(pnew) * u.d,
+ uncertainties.unumpy.std_devs(pnew) * u.d,
+ )
+
+ elif self.FB0.quantity is not None:
+ # assume FB terms
+ dt = (t - t0).sec
+ coeffs = []
+ unit = u.Hz
+ for p in self.get_prefix_mapping_component("FB").values():
+ coeffs.append(getattr(self, p).as_ufloat(unit))
+ unit /= u.s
+ pnew = 1 / taylor_horner_deriv(dt, coeffs, deriv_order=0)
+ if not isinstance(pnew, np.ndarray):
+ return pnew.n * u.s, pnew.s * u.s if pnew.s > 0 else None
+ import uncertainties.unumpy
+
+ return (
+ uncertainties.unumpy.nominal_values(pnew) * u.s,
+ uncertainties.unumpy.std_devs(pnew) * u.s,
+ )
+ raise AttributeError("Neither PB nor FB0 is present in the timing model.")
diff --git a/src/pint/models/solar_system_shapiro.py b/src/pint/models/solar_system_shapiro.py
index c06d5a366..1cd850065 100644
--- a/src/pint/models/solar_system_shapiro.py
+++ b/src/pint/models/solar_system_shapiro.py
@@ -111,13 +111,13 @@ def solar_system_shapiro_delay(self, toas, acc_delay=None):
if self.PLANET_SHAPIRO.value:
for pl in ("jupiter", "saturn", "venus", "uranus", "neptune"):
delay[grp] += self.ss_obj_shapiro_delay(
- tbl[grp]["obs_" + pl + "_pos"],
+ tbl[grp][f"obs_{pl}_pos"],
psr_dir,
self._ss_mass_sec[pl],
)
- except KeyError:
+ except KeyError as e:
raise KeyError(
"Planet positions not found when trying to compute Solar System Shapiro delay. "
"Make sure that you include `planets=True` in your `get_TOAs()` call, or use `get_model_and_toas()`."
- )
+ ) from e
return delay * u.second
diff --git a/src/pint/models/solar_wind_dispersion.py b/src/pint/models/solar_wind_dispersion.py
index f9c9734ac..0f220e2e6 100644
--- a/src/pint/models/solar_wind_dispersion.py
+++ b/src/pint/models/solar_wind_dispersion.py
@@ -256,7 +256,13 @@ def _get_reference_time(
return default
-class SolarWindDispersion(Dispersion):
+class SolarWindDispersionBase(Dispersion):
+ """Abstract base class for solar wind dispersion components."""
+
+ pass
+
+
+class SolarWindDispersion(SolarWindDispersionBase):
"""Dispersion due to the solar wind (basic model).
The model is a simple spherically-symmetric model that is fit
@@ -368,13 +374,12 @@ def solar_wind_dm(self, toas):
"""
if self.NE_SW.value == 0:
return np.zeros(len(toas)) * u.pc / u.cm**3
- if self.SWM.value == 0 or self.SWM.value == 1:
- solar_wind_geometry = self.solar_wind_geometry(toas)
- solar_wind_dm = self.NE_SW.quantity * solar_wind_geometry
- else:
+ if self.SWM.value not in [0, 1]:
raise NotImplementedError(
- "Solar Dispersion Delay not implemented for SWM %d" % self.SWM.value
+ f"Solar Dispersion Delay not implemented for SWM {self.SWM.value}"
)
+ solar_wind_geometry = self.solar_wind_geometry(toas)
+ solar_wind_dm = self.NE_SW.quantity * solar_wind_geometry
return solar_wind_dm.to(u.pc / u.cm**3)
def solar_wind_delay(self, toas, acc_delay=None):
@@ -385,7 +390,7 @@ def solar_wind_delay(self, toas, acc_delay=None):
def d_dm_d_ne_sw(self, toas, param_name, acc_delay=None):
"""Derivative of of DM wrt the solar wind ne amplitude."""
- if self.SWM.value == 0 or self.SWM.value == 1:
+ if self.SWM.value in [0, 1]:
solar_wind_geometry = self.solar_wind_geometry(toas)
else:
raise NotImplementedError(
@@ -514,7 +519,7 @@ def get_min_dm(self):
)
-class SolarWindDispersionX(Dispersion):
+class SolarWindDispersionX(SolarWindDispersionBase):
"""This class provides a SWX model - multiple Solar Wind segments.
This model lets the user specify time ranges and fit for a different
@@ -713,8 +718,7 @@ def add_swx_range(
if int(index) in self.get_prefix_mapping_component("SWXDM_"):
raise ValueError(
- "Index '%s' is already in use in this model. Please choose another."
- % index
+ f"Index '{index}' is already in use in this model. Please choose another."
)
if isinstance(swxdm, u.quantity.Quantity):
swxdm = swxdm.to_value(u.pc / u.cm**3)
@@ -730,7 +734,7 @@ def add_swx_range(
swxp = swxp.value
self.add_param(
prefixParameter(
- name="SWXDM_" + i,
+ name=f"SWXDM_{i}",
units="pc cm^-3",
value=swxdm,
description="Max Solar Wind DM",
@@ -740,7 +744,7 @@ def add_swx_range(
)
self.add_param(
prefixParameter(
- name="SWXP_" + i,
+ name=f"SWXP_{i}",
value=swxp,
description="Solar wind power-law index",
parameter_type="float",
@@ -749,7 +753,7 @@ def add_swx_range(
)
self.add_param(
prefixParameter(
- name="SWXR1_" + i,
+ name=f"SWXR1_{i}",
units="MJD",
description="Beginning of SWX interval",
parameter_type="MJD",
@@ -759,7 +763,7 @@ def add_swx_range(
)
self.add_param(
prefixParameter(
- name="SWXR2_" + i,
+ name=f"SWXR2_{i}",
units="MJD",
description="End of SWX interval",
parameter_type="MJD",
@@ -781,11 +785,7 @@ def remove_swx_range(self, index):
Number or list/array of numbers corresponding to SWX indices to be removed from model.
"""
- if (
- isinstance(index, int)
- or isinstance(index, float)
- or isinstance(index, np.int64)
- ):
+ if isinstance(index, (int, float, np.int64)):
indices = [index]
elif isinstance(index, (list, np.ndarray)):
indices = index
@@ -807,10 +807,7 @@ def get_indices(self):
inds : np.ndarray
Array of SWX indices in model.
"""
- inds = []
- for p in self.params:
- if "SWXDM_" in p:
- inds.append(int(p.split("_")[-1]))
+ inds = [int(p.split("_")[-1]) for p in self.params if "SWXDM_" in p]
return np.array(inds)
def setup(self):
@@ -821,7 +818,7 @@ def setup(self):
if prefix_par.startswith("SWXDM_"):
# check to make sure power-law index is present
# if not, put in default
- p_name = "SWXP_" + pint.utils.split_prefixed_name(prefix_par)[1]
+ p_name = f"SWXP_{pint.utils.split_prefixed_name(prefix_par)[1]}"
if not hasattr(self, p_name):
self.add_param(
prefixParameter(
@@ -924,8 +921,8 @@ def swx_dm(self, toas):
# Get SWX delays
dm = np.zeros(len(tbl)) * self._parent.DM.units
for k, v in select_idx.items():
- dmmax = getattr(self, k).quantity
if len(v) > 0:
+ dmmax = getattr(self, k).quantity
dm[v] += (
dmmax
* (
@@ -1001,7 +998,7 @@ def d_dm_d_swxp(self, toas, param_name, acc_delay=None):
r1 = getattr(self, SWXR1_mapping[swxp_index]).quantity
r2 = getattr(self, SWXR2_mapping[swxp_index]).quantity
- swx_name = "SWXDM_" + pint.utils.split_prefixed_name(param_name)[1]
+ swx_name = f"SWXDM_{pint.utils.split_prefixed_name(param_name)[1]}"
condition = {swx_name: (r1.mjd, r2.mjd)}
select_idx = self.swx_toas_selector.get_select_index(
condition, tbl["mjd_float"]
@@ -1151,7 +1148,7 @@ def set_ne_sws(self, ne_sws):
sorted_list = sorted(SWXDM_mapping.keys())
if len(ne_sws) == 1:
ne_sws = ne_sws[0] * np.ones(len(sorted_list))
- if not len(sorted_list) == len(ne_sws):
+ if len(sorted_list) != len(ne_sws):
raise ValueError(
f"Length of input NE_SW values ({len(ne_sws)}) must match number of SWX segments ({len(sorted_list)})"
)
diff --git a/src/pint/models/spindown.py b/src/pint/models/spindown.py
index 04d4669a5..78bbba8cf 100644
--- a/src/pint/models/spindown.py
+++ b/src/pint/models/spindown.py
@@ -10,7 +10,13 @@
from pint.utils import split_prefixed_name, taylor_horner, taylor_horner_deriv
-class Spindown(PhaseComponent):
+class SpindownBase(PhaseComponent):
+ """An abstract base class to mark Spindown components."""
+
+ pass
+
+
+class Spindown(SpindownBase):
"""A simple timing model for an isolated pulsar.
This represents the pulsar's spin as a Taylor series,
@@ -91,11 +97,10 @@ def validate(self):
# Check continuity
self._parent.get_prefix_list("F", start_index=0)
# If F1 is set, we need PEPOCH
- if hasattr(self, "F1") and self.F1.value != 0.0:
- if self.PEPOCH.value is None:
- raise MissingParameter(
- "Spindown", "PEPOCH", "PEPOCH is required if F1 or higher are set"
- )
+ if hasattr(self, "F1") and self.F1.value != 0.0 and self.PEPOCH.value is None:
+ raise MissingParameter(
+ "Spindown", "PEPOCH", "PEPOCH is required if F1 or higher are set"
+ )
@property
def F_terms(self):
@@ -128,8 +133,7 @@ def get_dt(self, toas, delay):
phsepoch_ld = (tbl["tdb"][0] - delay[0]).tdb.mjd_long
else:
phsepoch_ld = self.PEPOCH.quantity.tdb.mjd_long
- dt = (tbl["tdbld"] - phsepoch_ld) * u.day - delay
- return dt
+ return (tbl["tdbld"] - phsepoch_ld) * u.day - delay
def spindown_phase(self, toas, delay):
"""Spindown phase function.
@@ -180,7 +184,7 @@ def change_pepoch(self, new_epoch, toas=None, delay=None):
fterms = [0.0 * u.Unit("")] + self.get_spin_terms()
# rescale the fterms
for n in range(len(fterms) - 1):
- f_par = getattr(self, "F{}".format(n))
+ f_par = getattr(self, f"F{n}")
f_par.value = taylor_horner_deriv(
dt.to(u.second), fterms, deriv_order=n + 1
)
diff --git a/src/pint/models/stand_alone_psr_binaries/BT_model.py b/src/pint/models/stand_alone_psr_binaries/BT_model.py
index 56cc4b2a7..2e8a121f1 100644
--- a/src/pint/models/stand_alone_psr_binaries/BT_model.py
+++ b/src/pint/models/stand_alone_psr_binaries/BT_model.py
@@ -215,33 +215,33 @@ def d_delayL2_d_T0(self):
def d_delayL1_d_par(self, par):
if par not in self.binary_params:
- errorMesg = par + " is not in binary parameter list."
+ errorMesg = f"{par} is not in binary parameter list."
raise ValueError(errorMesg)
par_obj = getattr(self, par)
- if hasattr(self, "d_delayL1_d_" + par):
- func = getattr(self, "d_delayL1_d_" + par)
- return func()
- else:
- if par in self.orbits_cls.orbit_params:
- return self.d_delayL1_d_E() * self.d_E_d_par(par)
- else:
- return np.zeros(len(self.t)) * u.second / par_obj.unit
+ if not hasattr(self, f"d_delayL1_d_{par}"):
+ return (
+ self.d_delayL1_d_E() * self.d_E_d_par(par)
+ if par in self.orbits_cls.orbit_params
+ else np.zeros(len(self.t)) * u.second / par_obj.unit
+ )
+ func = getattr(self, f"d_delayL1_d_{par}")
+ return func()
def d_delayL2_d_par(self, par):
if par not in self.binary_params:
- errorMesg = par + " is not in binary parameter list."
+ errorMesg = f"{par} is not in binary parameter list."
raise ValueError(errorMesg)
par_obj = getattr(self, par)
- if hasattr(self, "d_delayL2_d_" + par):
- func = getattr(self, "d_delayL2_d_" + par)
- return func()
- else:
- if par in self.orbits_cls.orbit_params:
- return self.d_delayL2_d_E() * self.d_E_d_par(par)
- else:
- return np.zeros(len(self.t)) * u.second / par_obj.unit
+ if not hasattr(self, f"d_delayL2_d_{par}"):
+ return (
+ self.d_delayL2_d_E() * self.d_E_d_par(par)
+ if par in self.orbits_cls.orbit_params
+ else np.zeros(len(self.t)) * u.second / par_obj.unit
+ )
+ func = getattr(self, f"d_delayL2_d_{par}")
+ return func()
def d_BTdelay_d_par(self, par):
return self.delayR() * (self.d_delayL1_d_par(par) + self.d_delayL2_d_par(par))
diff --git a/src/pint/models/stand_alone_psr_binaries/DDGR_model.py b/src/pint/models/stand_alone_psr_binaries/DDGR_model.py
new file mode 100644
index 000000000..85c40fd39
--- /dev/null
+++ b/src/pint/models/stand_alone_psr_binaries/DDGR_model.py
@@ -0,0 +1,709 @@
+"""The DDGR model - Damour and Deruelle with GR assumed"""
+import astropy.constants as c
+import astropy.units as u
+import numpy as np
+from loguru import logger as log
+
+from .DD_model import DDmodel
+
+
+@u.quantity_input(M1=u.Msun, M2=u.Msun, n=1 / u.d)
+def _solve_kepler(M1, M2, n, ARTOL=1e-10):
+ """Relativistic version of Kepler's third law, solved by iteration
+
+ Taylor & Weisberg (1989), Eqn. 15
+ In tempo, implemented as ``mass2dd`` (https://sourceforge.net/p/tempo/tempo/ci/master/tree/src/mass2dd.f)
+
+
+ Parameters
+ ----------
+ M1 : astropy.units.Quantity
+ Mass of pulsar
+ M2 : astropy.units.Quantity
+ Mass of companion
+ n : astropy.units.Quantity
+ orbital angular frequency
+ ARTOL : float
+ fractional tolerance for solution
+
+ Returns
+ -------
+ arr0 : astropy.units.Quantity
+ non-relativistic semi-major axis
+ arr : astropy.units.Quantity
+ relativstic semi-major axis
+ """
+ MTOT = M1 + M2
+ # initial NR value
+ arr0 = (c.G * MTOT / n**2) ** (1.0 / 3)
+ arr = arr0
+ arr_old = arr
+ arr = arr0 * (
+ 1 + (M1 * M2 / MTOT**2 - 9) * (c.G * MTOT / (2 * arr * c.c**2))
+ ) ** (2.0 / 3)
+ # iterate to get correct value
+ while np.fabs((arr - arr_old) / arr) > ARTOL:
+ arr_old = arr
+ ar = arr0 * (
+ 1 + (M1 * M2 / MTOT**2 - 9) * (c.G * MTOT / (2 * arr * c.c**2))
+ ) ** (2.0 / 3)
+
+ return arr0.decompose(), arr.decompose()
+
+
+class DDGRmodel(DDmodel):
+ """Damour and Deruelle model assuming GR to be correct
+
+ It supports all the parameters defined in :class:`pint.models.pulsar_binary.PulsarBinary`
+ and :class:`pint.models.binary_dd.BinaryDD` plus:
+
+ MTOT
+ Total mass
+ XPBDOT
+ Excess PBDOT beyond what GR predicts
+ XOMDOT
+ Excess OMDOT beyond what GR predicts
+
+ It also reads but ignores:
+
+ SINI
+ PBDOT
+ OMDOT
+ GAMMA
+ DR
+ DTH
+
+ Parameters supported:
+
+ .. paramtable::
+ :class: pint.models.binary_dd.BinaryDDGR
+
+ References
+ ----------
+ - Taylor and Weisberg (1989), ApJ, 345, 434 [tw89]_
+
+ .. [tw89] https://ui.adsabs.harvard.edu/abs/1989ApJ...345..434T/abstract
+ """
+
+ def __init__(self, t=None, input_params=None):
+ super().__init__()
+ self.binary_name = "DDS"
+ self.param_default_value.update(
+ {"MTOT": 2.8 * u.Msun, "XOMDOT": 0 * u.deg / u.yr, "XPBDOT": 0 * u.s / u.s}
+ )
+
+ # If any parameter has aliases, it should be updated
+ # self.param_aliases.update({})
+ self.binary_params = list(self.param_default_value.keys())
+ # Remove unused parameter SINI and others
+ for p in ["SINI", "PBDOT", "OMDOT", "GAMMA", "DR", "DTH"]:
+ del self.param_default_value[p]
+ self.set_param_values()
+ if input_params is not None:
+ self.update_input(param_dict=input_params)
+
+ def _updatePK(self, ARTOL=1e-10):
+ """Update measurable PK quantities from system parameters for DDGR model
+
+ Taylor & Weisberg (1989), Eqn. 15-25
+ In tempo, implemented as ``mass2dd`` (https://sourceforge.net/p/tempo/tempo/ci/master/tree/src/mass2dd.f)
+
+ Parameters
+ ----------
+ ARTOL : float
+ fractional tolerance for solution of relativistic Kepler equation (passed to :func:`_solve_kepler`)
+
+ """
+ # if not all of the required parameters have been set yet (e.g., while initializing)
+ # don't do anything
+ for p in ["PB", "ECC", "M2", "MTOT", "A1"]:
+ if not hasattr(self, p) or getattr(self, p).value is None:
+ return
+
+ # unclear if this should compute the PB in a different way
+ # since this could incorporate changes, but to determine those changes we need to run this function
+ PB = self.PB.to(u.s)
+ self._M1 = self.MTOT - self.M2
+ self._n = 2 * np.pi / PB
+ arr0, arr = _solve_kepler(self._M1, self.M2, self._n, ARTOL=ARTOL)
+ self._arr = arr
+ # pulsar component of semi-major axis
+ self._ar = self._arr * (self.M2 / self.MTOT)
+ # Taylor & Weisberg (1989), Eqn. 20
+ self._SINI = (self.a1() / self._ar).decompose()
+ # Taylor & Weisberg (1989), Eqn. 17
+ # use arr0 here following comments in tempo
+ self._GAMMA = (
+ self.ecc()
+ * c.G
+ * self.M2
+ * (self._M1 + 2 * self.M2)
+ / (self._n * c.c**2 * arr0 * self.MTOT)
+ ).to(u.s)
+ # Taylor & Weisberg (1989), Eqn. 18
+ self._PBDOT = (
+ (-192 * np.pi / (5 * c.c**5))
+ * (c.G * self._n) ** (5.0 / 3)
+ * self._M1
+ * self.M2
+ * self.MTOT ** (-1.0 / 3)
+ * self.fe
+ ).decompose()
+ # we calculate this here although we don't need it for DDGR
+ self._OMDOT = (
+ 3
+ * (self._n) ** (5.0 / 3)
+ * (1 / (1 - self.ecc() ** 2))
+ * (c.G * (self._M1 + self.M2) / c.c**3) ** (2.0 / 3)
+ ).to(u.deg / u.yr, equivalencies=u.dimensionless_angles())
+ # Taylor & Weisberg (1989), Eqn. 16
+ # use arr0 here following comments in tempo
+ self._k = (
+ (3 * c.G * self.MTOT) / (c.c**2 * arr0 * (1 - self.ecc() ** 2))
+ ).decompose()
+ # Taylor & Weisberg (1989), Eqn. 24
+ self._DR = (
+ (c.G / (c.c**2 * self.MTOT * self._arr))
+ * (3 * self._M1**2 + 6 * self._M1 * self.M2 + 2 * self.M2**2)
+ ).decompose()
+ # Damour & Deruelle (1986), Eqn. 36
+ self._er = self.ecc() * (1 + self._DR)
+ # Taylor & Weisberg (1989), Eqn. 25
+ self._DTH = (
+ (c.G / (c.c**2 * self.MTOT * self._arr))
+ * (3.5 * self._M1**2 + 6 * self._M1 * self.M2 + 2 * self.M2**2)
+ ).decompose()
+ # Damour & Deruelle (1986), Eqn. 37
+ self._eth = self.ecc() * (1 + self._DTH)
+
+ @property
+ def fe(self):
+ # Taylor & Weisberg (1989), Eqn. 19
+ return (1 + (73.0 / 24) * self.ecc() ** 2 + (37.0 / 96) * self.ecc() ** 4) * (
+ 1 - self.ecc() ** 2
+ ) ** (-7.0 / 2)
+
+ ####################
+ @property
+ def arr(self):
+ return self._arr
+
+ def d_arr_d_M2(self):
+ an = 2 * np.pi / self.pb()
+ return (
+ -9
+ * c.G**2
+ * (
+ -2.0 / 9 * self.MTOT * self.arr * c.c**2
+ + c.G * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9)
+ )
+ * self.arr
+ * (-2 * self.M2 + self.MTOT)
+ / (
+ 6 * an**2 * self.arr**5 * self.MTOT * c.c**4
+ - 18
+ * c.G**2
+ * self.MTOT
+ * c.c**2
+ * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9)
+ * self.arr
+ + 81
+ * c.G**3
+ * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9) ** 2
+ )
+ )
+
+ def d_arr_d_MTOT(self):
+ an = 2 * np.pi / self.pb()
+ return (
+ c.G
+ * self.arr
+ * (
+ -2 * self.MTOT * self.arr * c.c**2
+ + 9 * c.G * self.MTOT**2
+ - c.G * self.MTOT * self.M2
+ + c.G * self.M2**2
+ )
+ * (
+ -2 * self.MTOT * self.arr * c.c**2
+ + 27 * c.G * self.MTOT**2
+ - c.G * self.MTOT * self.M2
+ - c.G * self.M2**2
+ )
+ / self.MTOT
+ / (
+ 2.0 / 27 * an**2 * self.arr**5 * self.MTOT * c.c**4
+ - 2.0
+ / 9
+ * c.G**2
+ * self.MTOT
+ * c.c**2
+ * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9)
+ * self.arr
+ + c.G**3
+ * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9) ** 2
+ )
+ / 162
+ )
+
+ def d_arr_d_PB(self):
+ return (
+ 16.0
+ / 81
+ * np.pi**2
+ * self.arr**6
+ * self.MTOT
+ * c.c**4
+ / (
+ 0.8e1 / 0.27e2 * self.MTOT * np.pi**2 * c.c**4 * self.arr**5
+ - 2.0
+ / 9
+ * c.G**2
+ * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9)
+ * self.MTOT
+ * self.pb() ** 2
+ * c.c**2
+ * self.arr
+ + c.G**3
+ * (self.MTOT**2 - self.M2 * self.MTOT / 9 + self.M2**2 / 9) ** 2
+ * self.pb() ** 2
+ )
+ / self.pb()
+ )
+
+ ####################
+ @property
+ def k(self):
+ """Precessing rate assuming GR
+
+ Taylor and Weisberg (1989), Eqn. 16
+ """
+ return self._k
+
+ def d_k_d_MTOT(self):
+ return self.k / self.MTOT - self.k * self.d_arr_d_MTOT() / self.arr
+
+ def d_k_d_M2(self):
+ return -self.k * self.d_arr_d_M2() / self.arr
+
+ def d_k_d_ECC(self):
+ return (
+ 6
+ * (c.G * self.MTOT * self._n) ** (2.0 / 3)
+ * self.ecc()
+ / (c.c**2 * (1 - self.ecc() ** 2) ** 2)
+ )
+
+ def d_k_d_PB(self):
+ return -(self.k / self.arr) * self.d_arr_d_PB()
+
+ def d_k_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_k_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
+ return ko_func()
+
+ ####################
+ def omega(self):
+ """Longitude of periastron
+
+ omega = OM + nu * k + nu * XOMDOT / n
+
+ Like DD model, but add in any excess OMDOT from the XOMDOT term
+ """
+ return (
+ self.OM + self.nu() * self.k + self.nu() * (self.XOMDOT / (self._n * u.rad))
+ ).to(u.rad)
+
+ def d_omega_d_par(self, par):
+ """derivative for omega respect to user input Parameter.
+
+ Calculates::
+
+ if par is not 'OM','XOMDOT','MTOT','M2'
+ dOmega/dPar = k*dAe/dPar
+ k = OMDOT/n
+
+ Parameters
+ ----------
+ par : string
+ parameter name
+
+ Returns
+ -------
+ Derivative of omega respect to par
+ """
+ par_obj = getattr(self, par)
+
+ PB = self.pb()
+ OMDOT = self.OMDOT
+ OM = self.OM
+ nu = self.nu()
+ if par in ["OM", "XOMDOT", "MTOT", "M2"]:
+ # calculate the derivative directly
+ dername = f"d_omega_d_{par}"
+ return getattr(self, dername)()
+ elif par in self.orbits_cls.orbit_params:
+ # a function of both nu and k
+ d_nu_d_par = self.d_nu_d_par(par)
+ d_pb_d_par = self.d_pb_d_par(par)
+ return d_nu_d_par * self.k + d_pb_d_par * nu * OMDOT.to(
+ u.rad / u.second
+ ) / (2 * np.pi * u.rad)
+ else:
+ # For parameters only in nu
+ return (self.k * self.d_nu_d_par(par)).to(
+ OM.unit / par_obj.unit, equivalencies=u.dimensionless_angles()
+ )
+
+ def d_omega_d_MTOT(self):
+ return (self.k + self.XOMDOT / self._n / u.rad) * self.d_nu_d_MTOT().to(
+ u.rad / u.Msun, equivalencies=u.dimensionless_angles()
+ ) + self.nu() * self.d_k_d_MTOT()
+
+ def d_omega_d_M2(self):
+ return self.nu() * self.d_k_d_M2()
+
+ def d_omega_d_XOMDOT(self):
+ """Derivative.
+
+ Calculates::
+
+ dOmega/dXOMDOT = 1/n*nu
+ n = 2*pi/PB
+ dOmega/dXOMDOT = PB/2*pi*nu
+ """
+ return self.nu() / (self._n * u.rad)
+
+ ####################
+ @property
+ def SINI(self):
+ return self._SINI
+
+ def d_SINI_d_MTOT(self):
+ return (
+ -(self.MTOT * self.A1 / (self.arr * self.M2))
+ * (-1 / self.MTOT + self.d_arr_d_MTOT() / self.arr)
+ ).decompose()
+
+ def d_SINI_d_M2(self):
+ return (
+ -(self.MTOT * self.a1() / (self.arr * self.M2))
+ * (1.0 / self.M2 + self.d_arr_d_M2() / self.arr)
+ ).decompose()
+
+ def d_SINI_d_PB(self):
+ return -(self.SINI / self.arr) * self.d_arr_d_PB()
+
+ def d_SINI_d_A1(self):
+ return (self.MTOT**2 * self._n**2 / c.G) ** (1.0 / 3) / self.M2
+
+ def d_SINI_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_SINI_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
+ return ko_func()
+
+ ####################
+ @property
+ def GAMMA(self):
+ return self._GAMMA
+
+ def d_GAMMA_d_ECC(self):
+ return self.GAMMA / self.ecc()
+
+ def d_GAMMA_d_MTOT(self):
+ return (c.G / c.c**2) * (
+ (
+ (1 / (self.arr * self.MTOT))
+ - (self.MTOT + self.M2) / (self.arr * self.MTOT**2)
+ - (self.MTOT + self.M2)
+ * self.d_arr_d_MTOT()
+ / (self.arr**2 * self.MTOT)
+ )
+ * self.ecc()
+ * self.M2
+ / self._n
+ )
+
+ def d_GAMMA_d_M2(self):
+ # Note that this equation in Tempo2 may have the wrong sign
+ return -(
+ c.G
+ / c.c**2
+ * (
+ (
+ self.M2 * (self.MTOT + self.M2) * self.d_arr_d_M2() / self.arr**2
+ - (self.MTOT + 2 * self.M2) / self.arr
+ )
+ * self.ecc()
+ / self._n
+ / self.MTOT
+ ).decompose()
+ )
+
+ def d_GAMMA_d_PB(self):
+ return self.GAMMA / self.PB - (self.GAMMA / self.arr) * self.d_arr_d_PB()
+
+ def d_GAMMA_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_GAMMA_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.s / par_obj.unit
+ return ko_func()
+
+ ####################
+ @property
+ def PBDOT(self):
+ # don't add XPBDOT here: that is handled by the normal binary objects
+ return self._PBDOT
+
+ def d_PBDOT_d_MTOT(self):
+ return self.PBDOT / (self.MTOT - self.M2) - self.PBDOT / 3 / self.MTOT
+
+ def d_PBDOT_d_M2(self):
+ return self.PBDOT / self.M2 - self.PBDOT / (self.MTOT - self.M2)
+
+ def d_PBDOT_d_ECC(self):
+ return (
+ -(222 * np.pi / 5 / c.c**5)
+ * self.ecc()
+ * (c.G**5 * self._n**5 / self.MTOT) ** (1.0 / 3)
+ * self.M2
+ * (self.MTOT - self.M2)
+ * (self.ecc() ** 4 + (536.0 / 37) * self.ecc() ** 2 + 1256.0 / 111)
+ / (1 - self.ecc() ** 2) ** (9.0 / 2)
+ )
+
+ def d_PBDOT_d_PB(self):
+ return (
+ 128
+ * self.fe
+ * (4 * c.G**5 * np.pi**8 / self.PB**8 / self.MTOT) ** (1.0 / 3)
+ * (self.MTOT - self.M2)
+ / c.c**5
+ )
+
+ def d_PBDOT_d_XPBDOT(self):
+ return np.ones(len(self.tt0)) * u.Unit("")
+
+ def d_PBDOT_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_PBDOT_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
+ return ko_func()
+
+ ####################
+ # other derivatives
+ def d_E_d_MTOT(self):
+ """Eccentric anomaly has MTOT dependence through PBDOT and Kepler's equation"""
+ d_M_d_MTOT = (
+ -2 * np.pi * self.tt0**2 / (2 * self.PB**2) * self.d_PBDOT_d_MTOT()
+ )
+ return d_M_d_MTOT / (1.0 - np.cos(self.E()) * self.ecc())
+
+ def d_nu_d_MTOT(self):
+ """True anomaly nu has MTOT dependence through PBDOT"""
+ return self.d_nu_d_E() * self.d_E_d_MTOT()
+
+ ####################
+ @property
+ def OMDOT(self):
+ # don't need an explicit OMDOT here since the main precession is carried in the k term
+ return self.XOMDOT
+
+ def d_OMDOT_d_par(self, par):
+ par_obj = getattr(self, par)
+ if par == "XOMDOT":
+ return lambda: np.ones(len(self.tt0)) * (u.deg / u.yr) / par_obj.unit
+ else:
+ return lambda: np.zeros(len(self.tt0)) * (u.deg / u.yr) / par_obj.unit
+
+ ####################
+ @property
+ def DR(self):
+ return self._DR
+
+ def d_DR_d_MTOT(self):
+ return (
+ -self.DR / self.MTOT
+ - self.DR * self.d_arr_d_MTOT() / self.arr
+ + 6 * (c.G / c.c**2) / self.arr
+ )
+
+ def d_DR_d_M2(self):
+ return -self.DR * self.d_arr_d_M2() / self.arr - 2 * (
+ c.G / c.c**2
+ ) * self.M2 / (self.arr * self.MTOT)
+
+ def d_DR_d_PB(self):
+ return -(self.DR / self.arr) * self.d_arr_d_PB()
+
+ def d_DR_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_DR_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
+ return ko_func()
+
+ ####################
+ @property
+ def DTH(self):
+ return self._DTH
+
+ def d_DTH_d_MTOT(self):
+ return (
+ -self.DTH / self.MTOT
+ - self.DTH * self.d_arr_d_MTOT() / self.arr
+ + (c.G / c.c**2) * (7 * self.MTOT - self.M2) / (self.arr * self.MTOT)
+ )
+
+ def d_DTH_d_M2(self):
+ return -self.DTH * self.d_arr_d_M2() / self.arr - (c.G / c.c**2) * (
+ self.MTOT + self.M2
+ ) / (self.arr * self.MTOT)
+
+ def d_DTH_d_PB(self):
+ return -(self.DTH / self.arr) * self.d_arr_d_PB()
+
+ def d_DTH_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_DTH_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
+ return ko_func()
+
+ def er(self):
+ return self._er
+
+ def eTheta(self):
+ return self._eth
+
+ def d_er_d_MTOT(self):
+ return self.ecc() * self.d_DR_d_MTOT()
+
+ def d_er_d_M2(self):
+ return self.ecc() * self.d_DR_d_M2()
+
+ def d_eTheta_d_MTOT(self):
+ return self.ecc() * self.d_DTH_d_MTOT()
+
+ def d_eTheta_d_M2(self):
+ return self.ecc() * self.d_DTH_d_M2()
+
+ def d_beta_d_MTOT(self):
+ return (
+ -(self.beta() / (1 - self.eTheta() ** 2) ** 0.5) * self.d_eTheta_d_MTOT()
+ - (self.a1() / c.c)
+ * (1 - self.eTheta() ** 2) ** 0.5
+ * np.sin(self.omega())
+ * self.d_omega_d_MTOT()
+ )
+
+ def d_beta_d_M2(self):
+ return (
+ -(self.beta() / (1 - self.eTheta() ** 2) ** 0.5) * self.d_eTheta_d_M2()
+ - (self.a1() / c.c)
+ * (1 - self.eTheta() ** 2) ** 0.5
+ * np.sin(self.omega())
+ * self.d_omega_d_M2()
+ )
+
+ @SINI.setter
+ def SINI(self, val):
+ log.debug(
+ "DDGR model uses MTOT to derive the inclination angle. SINI will not be used."
+ )
+
+ @PBDOT.setter
+ def PBDOT(self, val):
+ log.debug("DDGR model uses MTOT to derive PBDOT. PBDOT will not be used.")
+
+ @OMDOT.setter
+ def OMDOT(self, val):
+ log.debug("DDGR model uses MTOT to derive OMDOT. OMDOT will not be used.")
+
+ @GAMMA.setter
+ def GAMMA(self, val):
+ log.debug("DDGR model uses MTOT to derive GAMMA. GAMMA will not be used.")
+
+ @DR.setter
+ def DR(self, val):
+ log.debug("DDGR model uses MTOT to derive Dr. Dr will not be used.")
+
+ @DTH.setter
+ def DTH(self, val):
+ log.debug("DDGR model uses MTOT to derive Dth. Dth will not be used.")
+
+ # wrap these properties so that we can update the PK calculations when they are set
+ @property
+ def PB(self):
+ return self._PB
+
+ @PB.setter
+ def PB(self, val):
+ self._PB = val
+ self._updatePK()
+
+ @property
+ def MTOT(self):
+ return self._MTOT
+
+ @MTOT.setter
+ def MTOT(self, val):
+ self._MTOT = val
+ self._updatePK()
+
+ @property
+ def M2(self):
+ return self._M2
+
+ @M2.setter
+ def M2(self, val):
+ self._M2 = val
+ self._updatePK()
+
+ @property
+ def A1(self):
+ return self._A1
+
+ @A1.setter
+ def A1(self, val):
+ self._A1 = val
+ self._updatePK()
+
+ @property
+ def ECC(self):
+ return self._ECC
+
+ @ECC.setter
+ def ECC(self, val):
+ self._ECC = val
+ self._updatePK()
+
+ @property
+ def A1DOT(self):
+ return self._A1DOT
+
+ @A1DOT.setter
+ def A1DOT(self, val):
+ self._A1DOT = val
+ self._updatePK()
+
+ @property
+ def EDOT(self):
+ return self._EDOT
+
+ @EDOT.setter
+ def EDOT(self, val):
+ self._EDOT = val
+ self._updatePK()
diff --git a/src/pint/models/stand_alone_psr_binaries/DDK_model.py b/src/pint/models/stand_alone_psr_binaries/DDK_model.py
index a76ddfacc..709033aae 100644
--- a/src/pint/models/stand_alone_psr_binaries/DDK_model.py
+++ b/src/pint/models/stand_alone_psr_binaries/DDK_model.py
@@ -78,6 +78,8 @@ def __init__(self, t=None, input_params=None):
# Remove unused parameter SINI
del self.param_default_value["SINI"]
self.set_param_values()
+ if input_params is not None:
+ self.update_input(param_dict=input_params)
@property
def KOM(self):
@@ -136,10 +138,7 @@ def cos_long(self):
@property
def SINI(self):
- if hasattr(self, "_tt0"):
- return np.sin(self.kin())
- else:
- return np.sin(self.KIN)
+ return np.sin(self.kin()) if hasattr(self, "_tt0") else np.sin(self.KIN)
@SINI.setter
def SINI(self, val):
@@ -172,41 +171,34 @@ def delta_kin_proper_motion(self):
return d_KIN.to(self.KIN.unit)
def kin(self):
- if self.K96:
- return self.KIN + self.delta_kin_proper_motion()
- else:
- return self.KIN
+ return self.KIN + self.delta_kin_proper_motion() if self.K96 else self.KIN
def d_SINI_d_KIN(self):
# with u.set_enabled_equivalencies(u.dimensionless_angles()):
return np.cos(self.kin()).to(u.Unit("") / self.KIN.unit)
def d_SINI_d_KOM(self):
- if self.K96:
- d_si_d_kom = (
- (-self.PMLONG_DDK * self.cos_KOM - self.PMLAT_DDK * self.sin_KOM)
- * self.tt0
- * np.cos(self.kin())
- )
- # with u.set_enabled_equivalencies(u.dimensionless_angles()):
- return d_si_d_kom.to(u.Unit("") / self.KOM.unit)
- else:
+ if not self.K96:
return np.cos(self.kin()) * u.Unit("") / self.KOM.unit
+ d_si_d_kom = (
+ (-self.PMLONG_DDK * self.cos_KOM - self.PMLAT_DDK * self.sin_KOM)
+ * self.tt0
+ * np.cos(self.kin())
+ )
+ # with u.set_enabled_equivalencies(u.dimensionless_angles()):
+ return d_si_d_kom.to(u.Unit("") / self.KOM.unit)
def d_SINI_d_T0(self):
- if self.K96:
- d_si_d_kom = -(
- -self.PMLONG_DDK * self.sin_KOM + self.PMLAT_DDK * self.cos_KOM
- )
- return d_si_d_kom.to(u.Unit("") / self.T0.unit)
- else:
+ if not self.K96:
return np.ones(len(self.tt0)) * u.Unit("") / self.T0.unit
+ d_si_d_kom = -(-self.PMLONG_DDK * self.sin_KOM + self.PMLAT_DDK * self.cos_KOM)
+ return d_si_d_kom.to(u.Unit("") / self.T0.unit)
def d_SINI_d_par(self, par):
par_obj = getattr(self, par)
try:
- ko_func = getattr(self, "d_SINI_d_" + par)
- except:
+ ko_func = getattr(self, f"d_SINI_d_{par}")
+ except Exception:
ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
return ko_func()
@@ -230,14 +222,13 @@ def d_kin_d_par(self, par):
if par == "KIN":
return np.ones_like(self.tt0)
par_obj = getattr(self, par)
- if self.K96:
- try:
- func = getattr(self, "d_kin_proper_motion_d_" + par)
- except:
- func = lambda: np.zeros(len(self.tt0)) * self.KIN / par_obj.unit
- return func()
- else:
+ if not self.K96:
return np.zeros(len(self.tt0)) * self.KIN / par_obj.unit
+ try:
+ func = getattr(self, f"d_kin_proper_motion_d_{par}")
+ except Exception:
+ func = lambda: np.zeros(len(self.tt0)) * self.KIN / par_obj.unit
+ return func()
def delta_a1_proper_motion(self):
"""The correction on a1 (projected semi-major axis)
@@ -409,10 +400,7 @@ def delta_a1_parallax(self):
"""
Reference: (Kopeikin 1995 Eq 18)
"""
- if self.K96:
- p_motion = True
- else:
- p_motion = False
+ p_motion = bool(self.K96)
a1 = self.a1_k(proper_motion=p_motion, parallax=False)
kin = self.kin()
tan_kin = np.tan(kin)
@@ -426,10 +414,7 @@ def delta_a1_parallax(self):
return delta_a1.to(a1.unit)
def d_delta_a1_parallax_d_KIN(self):
- if self.K96:
- p_motion = True
- else:
- p_motion = False
+ p_motion = bool(self.K96)
a1 = self.a1_k(proper_motion=p_motion, parallax=False)
d_a1_d_kin = self.d_a1_k_d_par("KIN", proper_motion=p_motion, parallax=False)
kin = self.kin()
@@ -444,10 +429,7 @@ def d_delta_a1_parallax_d_KIN(self):
return d_delta_a1_d_KIN.to(a1.unit / kin.unit)
def d_delta_a1_parallax_d_KOM(self):
- if self.K96:
- p_motion = True
- else:
- p_motion = False
+ p_motion = bool(self.K96)
a1 = self.a1_k(proper_motion=p_motion, parallax=False)
d_a1_d_kom = self.d_a1_k_d_par("KOM", proper_motion=p_motion, parallax=False)
kin = self.kin()
@@ -469,10 +451,7 @@ def d_delta_a1_parallax_d_KOM(self):
return d_delta_a1_d_KOM.to(a1.unit / self.KOM.unit)
def d_delta_a1_parallax_d_T0(self):
- if self.K96:
- p_motion = True
- else:
- p_motion = False
+ p_motion = bool(self.K96)
a1 = self.a1_k(proper_motion=p_motion, parallax=False)
d_a1_d_T0 = self.d_a1_k_d_par("T0", proper_motion=p_motion, parallax=False)
kin = self.kin()
@@ -565,10 +544,7 @@ def a1_k(self, proper_motion=True, parallax=True):
return a1
def a1(self):
- if self.K96:
- return self.a1_k()
- else:
- return self.a1_k(proper_motion=False)
+ return self.a1_k() if self.K96 else self.a1_k(proper_motion=False)
def d_a1_k_d_par(self, par, proper_motion=True, parallax=True):
result = super().d_a1_d_par(par)
@@ -577,7 +553,7 @@ def d_a1_k_d_par(self, par, proper_motion=True, parallax=True):
if flag:
try:
ko_func = getattr(self, ko_func_name[ii] + par)
- except:
+ except Exception:
ko_func = lambda: np.zeros(len(self.tt0)) * result.unit
result += ko_func()
return result
@@ -607,10 +583,7 @@ def omega_k(self, proper_motion=True, parallax=True):
return omega
def omega(self):
- if self.K96:
- return self.omega_k()
- else:
- return self.omega_k(proper_motion=False)
+ return self.omega_k() if self.K96 else self.omega_k(proper_motion=False)
def d_omega_k_d_par(self, par, proper_motion=True, parallax=True):
result = super().d_omega_d_par(par)
diff --git a/src/pint/models/stand_alone_psr_binaries/DDS_model.py b/src/pint/models/stand_alone_psr_binaries/DDS_model.py
new file mode 100644
index 000000000..1fc81fc56
--- /dev/null
+++ b/src/pint/models/stand_alone_psr_binaries/DDS_model.py
@@ -0,0 +1,81 @@
+"""The DDS model - Damour and Deruelle with alternate Shapiro delay parametrization."""
+import astropy.constants as c
+import astropy.units as u
+import numpy as np
+from loguru import logger as log
+
+from pint import Tsun
+
+from .DD_model import DDmodel
+
+
+class DDSmodel(DDmodel):
+ """Damour and Deruelle model with alternate Shapiro delay parameterization.
+
+ This extends the :class:`pint.models.binary_dd.BinaryDD` model with
+ :math:`SHAPMAX = -\log(1-s)` instead of just :math:`s=\sin i`, which behaves better
+ for :math:`\sin i` near 1. It does not (yet) implement the higher-order delays and lensing correction.
+
+ It supports all the parameters defined in :class:`pint.models.pulsar_binary.PulsarBinary`
+ and :class:`pint.models.binary_dd.BinaryDD` plus:
+
+ SHAPMAX
+ :math:`-\log(1-\sin i)`
+
+ It also removes:
+
+ SINI
+ use ``SHAPMAX`` instead
+
+ Parameters supported:
+
+ .. paramtable::
+ :class: pint.models.binary_dd.BinaryDDS
+
+ References
+ ----------
+ - Kramer et al. (2006), Science, 314, 97 [klm+06]_
+ - Rafikov and Lai (2006), PRD, 73, 063003 [rl06]_
+
+ .. [klm+06] https://ui.adsabs.harvard.edu/abs/2006Sci...314...97K/abstract
+ .. [rl06] https://ui.adsabs.harvard.edu/abs/2006PhRvD..73f3003R/abstract
+ """
+
+ def __init__(self, t=None, input_params=None):
+ super().__init__()
+ self.binary_name = "DDS"
+ self.param_default_value.update(
+ {
+ "SHAPMAX": 0,
+ }
+ )
+
+ # If any parameter has aliases, it should be updated
+ # self.param_aliases.update({})
+ self.binary_params = list(self.param_default_value.keys())
+ # Remove unused parameter SINI
+ del self.param_default_value["SINI"]
+ self.set_param_values()
+ if input_params is not None:
+ self.update_input(param_dict=input_params)
+
+ @property
+ def SINI(self):
+ return 1 - np.exp(-self.SHAPMAX)
+
+ @SINI.setter
+ def SINI(self, val):
+ log.debug(
+ "DDS model uses SHAPMAX as inclination parameter. SINI will not be used."
+ )
+
+ def d_SINI_d_SHAPMAX(self):
+ return np.exp(-self.SHAPMAX)
+
+ def d_SINI_d_par(self, par):
+ par_obj = getattr(self, par)
+ try:
+ ko_func = getattr(self, f"d_SINI_d_{par}")
+ except AttributeError:
+ ko_func = lambda: np.zeros(len(self.tt0)) * u.Unit("") / par_obj.unit
+ return ko_func()
diff --git a/src/pint/models/stand_alone_psr_binaries/DD_model.py b/src/pint/models/stand_alone_psr_binaries/DD_model.py
index 936199111..e9b82bdc3 100644
--- a/src/pint/models/stand_alone_psr_binaries/DD_model.py
+++ b/src/pint/models/stand_alone_psr_binaries/DD_model.py
@@ -2,6 +2,7 @@
import astropy.constants as c
import astropy.units as u
import numpy as np
+from loguru import logger as log
from pint import Tsun
@@ -70,6 +71,16 @@ def __init__(self, t=None, input_params=None):
# calculations for delays in DD model
+ @property
+ def k(self):
+ # separate this into a property so it can be calculated correctly in DDGR
+ # note that this include self.pb() in the calculation of k
+ # and self.pb() is PB + PBDOT*dt, so it can vary slightly
+ # compared to a definition that does not include PBDOT
+ # I am not certain about how this should be done
+ # but this is keeping the behavior consistent
+ return self.OMDOT.to(u.rad / u.second) / (2 * np.pi * u.rad / self.pb())
+
# DDmodel special omega.
def omega(self):
"""T. Damour and N. Deruelle (1986) equation [25]
@@ -81,13 +92,7 @@ def omega(self):
(T. Damour and N. Deruelle (1986) equation between Eq 16 Eq 17)
"""
- PB = self.pb()
- PB = PB.to("second")
- OMDOT = self.OMDOT
- OM = self.OM
- nu = self.nu()
- k = OMDOT.to(u.rad / u.second) / (2 * np.pi * u.rad / PB)
- return (OM + nu * k).to(u.rad)
+ return (self.OM + self.nu() * self.k).to(u.rad)
def d_omega_d_par(self, par):
"""derivative for omega respect to user input Parameter.
@@ -116,19 +121,18 @@ def d_omega_d_par(self, par):
OMDOT = self.OMDOT
OM = self.OM
nu = self.nu()
- k = OMDOT.to(u.rad / u.second) / (2 * np.pi * u.rad / PB)
if par in ["OM", "OMDOT"]:
dername = f"d_omega_d_{par}"
return getattr(self, dername)()
elif par in self.orbits_cls.orbit_params:
d_nu_d_par = self.d_nu_d_par(par)
d_pb_d_par = self.d_pb_d_par(par)
- return d_nu_d_par * k + d_pb_d_par * nu * OMDOT.to(u.rad / u.second) / (
- 2 * np.pi * u.rad
- )
+ return d_nu_d_par * self.k + d_pb_d_par * nu * OMDOT.to(
+ u.rad / u.second
+ ) / (2 * np.pi * u.rad)
else:
# For parameters only in nu
- return (k * self.d_nu_d_par(par)).to(
+ return (self.k * self.d_nu_d_par(par)).to(
OM.unit / par_obj.unit, equivalencies=u.dimensionless_angles()
)
@@ -838,7 +842,6 @@ def d_delayA_d_par(self, par):
decc_dpar = self.prtl_der("ecc", par)
daDelay_decc = A0 * sOmega + B0 * cOmega
-
return (
domega_dpar * daDelay_domega
+ dnu_dpar * daDelay_dnu
diff --git a/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py b/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py
index e479b9157..995a0b625 100644
--- a/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py
+++ b/src/pint/models/stand_alone_psr_binaries/ELL1H_model.py
@@ -3,27 +3,36 @@
# This Python file uses the following encoding: utf-8
import astropy.units as u
import numpy as np
+from loguru import logger as log
from .ELL1_model import ELL1BaseModel
class ELL1Hmodel(ELL1BaseModel):
- """ELL1H pulsar binary model using H3, H4 or STIGMA as shapiro delay parameter.
+ """ELL1H pulsar binary model using H3, H4 or STIGMA as shapiro delay parameters.
Note
----
- ELL1H model parameterize the shapiro delay differently compare to ELL1
- model. A fourier series expansion is used for the shapiro delay.
- Ds = -2r * (a0/2 + sum(a_k*cos(k*phi)) + sum(b_k * sin(k*phi))
- The first two harmonics are generally absorbed by ELL1 roemer delay.
- Thus, when ELL1 parameterize shapiro delay uses the series from the third
- harmonic or higher.
+ Based on Freire and Wex (2010)
+
+ The :class:`~pint.models.binary_ell1.BinaryELL1H` model parameterizes the Shapiro
+ delay differently compare to the :class:`~pint.models.binary_ell1.BinaryELL1`
+ model. A fourier series expansion is used for the Shapiro delay:
+
+ .. math::
+
+ \\Delta_S = -2r \\left( \\frac{a_0}{2} + \\Sum_k (a_k \\cos k\\phi + b_k \\sin k \phi) \\right)
+
+ The first two harmonics are generlly absorbed by the ELL1 Roemer delay.
+ Thus, :class:`~pint.models.binary_ell1.BinaryELL1H` uses the series from the third
+ harmonic and higher.
References
----------
- - Freire & Wex (2010), MNRAS, 409 (1), 199-212 [1]_
+ - Freire and Wex (2010), MNRAS, 409, 199 [1]_
.. [1] https://ui.adsabs.harvard.edu/abs/2010MNRAS.409..199F/abstract
+
"""
def __init__(self):
@@ -101,7 +110,7 @@ def _ELL1H_fourier_basis(self, k, derivative=False):
return pwr, basis_func
def fourier_component(self, stigma, k, factor_out_power=0):
- """(P. Freire and N. Wex 2010) paper Eq (13)
+ """Freire and Wex (2010), Eq (13)
Parameters
----------
@@ -117,7 +126,8 @@ def fourier_component(self, stigma, k, factor_out_power=0):
Returns
-------
- The coefficient of fourier component and the basis.
+ float
+ The coefficient of fourier component and the basis.
"""
if k != 0:
@@ -137,26 +147,23 @@ def fourier_component(self, stigma, k, factor_out_power=0):
def d_fourier_component_d_stigma(self, stigma, k, factor_out_power=0):
"""This is a method to compute the derivative of a fourier component."""
- if k != 0:
- pwr, basis_func = self._ELL1H_fourier_basis(k)
- # prevent factor out zeros.
- if stigma == 0.0 and k == factor_out_power:
- return 0.0, basis_func
- else:
- return (
- (-1) ** (pwr)
- * 2.0
- * (k - factor_out_power)
- / k
- * stigma ** (k - factor_out_power - 1),
- basis_func,
- )
- else:
- basis_func = np.cos
+
+ # prevent factor out zeros.
+ if k == 0:
# a0 is -1 * np.log(1 + stigma ** 2)
# But the in the Fourier series it is a0/2
+ return -2.0 / (1 + stigma**2.0) * stigma ** (1 - factor_out_power), np.cos
+ pwr, basis_func = self._ELL1H_fourier_basis(k)
+
+ if stigma == 0.0 and k == factor_out_power:
+ return 0.0, basis_func
+ else:
return (
- -2.0 / (1 + stigma**2.0) * stigma ** (1 - factor_out_power),
+ (-1) ** (pwr)
+ * 2.0
+ * (k - factor_out_power)
+ / k
+ * stigma ** (k - factor_out_power - 1),
basis_func,
)
@@ -189,9 +196,8 @@ def ELL1H_shapiro_delay_fourier_harms(
):
"""Fourier series harms of shapiro delay.
- One can select the start term and end term, in other
- words, a part the fourier series term can be selected.
- (P. Freire and N. Wex 2010) paper Eq (10)
+ One can select the start term and end term.
+ Freire and Wex (2010), Eq. (10)
Parameters
----------
@@ -203,16 +209,15 @@ def ELL1H_shapiro_delay_fourier_harms(
Returns
-------
- The summation of harmonics
+ np.ndarray
+ The summation of harmonics
"""
harms = np.zeros((len(selected_harms), len(phi)))
# To prevent factor out zeros
- if stigma == 0.0:
- if selected_harms.min() < factor_out_power:
- raise ValueError(
- "Can not factor_out_power can not bigger than"
- " the selected_harms."
- )
+ if stigma == 0.0 and selected_harms.min() < factor_out_power:
+ raise ValueError(
+ "Can not factor_out_power can not bigger than" " the selected_harms."
+ )
for ii, k in enumerate(selected_harms):
coeff, basis_func = self.fourier_component(
stigma, k, factor_out_power=factor_out_power
@@ -229,25 +234,23 @@ def d_ELL1H_fourier_harms_d_par(
par_obj = getattr(self, par)
try:
df_func = getattr(self, df_name)
- except:
+ except AttributeError:
return 0.0 * u.Unit(None) / par_obj.Unit
d_harms = np.zeros((len(selected_harms), len(phi)))
# To prevent factor out zeros
- if stigma == 0.0:
- if selected_harms.min() < factor_out_power:
- raise ValueError(
- "Can not factor_out_power can not bigger than"
- " the selected_harms."
- )
+ if stigma == 0.0 and selected_harms.min() < factor_out_power:
+ raise ValueError(
+ "Can not factor_out_power can not bigger than" " the selected_harms."
+ )
for ii, k in enumerate(selected_harms):
coeff, basis_func = df_func(stigma, k, factor_out_power=factor_out_power)
d_harms[ii] = coeff * basis_func(k * phi)
return np.sum(d_harms, axis=0)
def delayS3p_H3_STIGMA_approximate(self, H3, stigma, end_harm=6):
- """Shapiro delay third harmonics or higher harms
+ """Shapiro delay using third or higher harmonics, appropriate for medium inclinations.
- defined in the (P. Freire and N. Wex 2010) paper Eq (19).
+ defined in Freire and Wex (2010), Eq (19).
"""
Phi = self.Phi()
selected_harms = np.arange(3, end_harm + 1)
@@ -257,7 +260,7 @@ def delayS3p_H3_STIGMA_approximate(self, H3, stigma, end_harm=6):
return -2.0 * H3 * sum_fharms
def d_delayS3p_H3_STIGMA_approximate_d_H3(self, H3, stigma, end_harm=6):
- """derivative of delayS3p_H3_STIGMA with respect to H3"""
+ """derivative of delayS3p_H3_STIGMA with respect to H3"""
Phi = self.Phi()
selected_harms = np.arange(3, end_harm + 1)
sum_fharms = self.ELL1H_shapiro_delay_fourier_harms(
@@ -266,7 +269,7 @@ def d_delayS3p_H3_STIGMA_approximate_d_H3(self, H3, stigma, end_harm=6):
return -2.0 * sum_fharms
def d_delayS3p_H3_STIGMA_approximate_d_STIGMA(self, H3, stigma, end_harm=6):
- """derivative of delayS3p_H3_STIGMA with respect to STIGMA"""
+ """derivative of delayS3p_H3_STIGMA with respect to STIGMA"""
Phi = self.Phi()
selected_harms = np.arange(3, end_harm + 1)
sum_d_fharms = self.d_ELL1H_fourier_harms_d_par(
@@ -275,7 +278,7 @@ def d_delayS3p_H3_STIGMA_approximate_d_STIGMA(self, H3, stigma, end_harm=6):
return -2.0 * H3 * sum_d_fharms
def d_delayS3p_H3_STIGMA_approximate_d_Phi(self, H3, stigma, end_harm=6):
- """derivative of delayS3p_H3_STIGMA with respect to Phi"""
+ """derivative of delayS3p_H3_STIGMA with respect to Phi"""
Phi = self.Phi()
selected_harms = np.arange(3, end_harm + 1)
sum_d_fharms = self.d_ELL1H_fourier_harms_d_par(
@@ -284,9 +287,9 @@ def d_delayS3p_H3_STIGMA_approximate_d_Phi(self, H3, stigma, end_harm=6):
return -2.0 * H3 * sum_d_fharms
def delayS3p_H3_STIGMA_exact(self, H3, stigma, end_harm=None):
- """Shapiro delay third harmonics or higher harms
+ """Shapiro delay (3rd hamonic and higher) using the exact form for very high inclinations.
- exact format defined in the P. Freire and N. Wex 2010 paper Eq (28).
+ Defined in Freire and Wex (2010), Eq (28).
"""
Phi = self.Phi()
lognum = 1 + stigma**2 - 2 * stigma * np.sin(Phi)
@@ -302,10 +305,7 @@ def delayS3p_H3_STIGMA_exact(self, H3, stigma, end_harm=None):
)
def d_delayS3p_H3_STIGMA_exact_d_H3(self, H3, stigma, end_harm=None):
- """derivative of Shapiro delay third harmonics or higher harms
-
- exact format with respect to H3
- """
+ """derivative of exact Shapiro delay (3rd hamonic and higher) with respect to H3"""
Phi = self.Phi()
lognum = 1 + stigma**2 - 2 * stigma * np.sin(Phi)
return (
@@ -319,10 +319,7 @@ def d_delayS3p_H3_STIGMA_exact_d_H3(self, H3, stigma, end_harm=None):
)
def d_delayS3p_H3_STIGMA_exact_d_STIGMA(self, H3, stigma, end_harm=None):
- """derivative of Shapiro delay third harmonics or higher harms
-
- exact format with respect to STIGMA
- """
+ """derivative of exact Shapiro delay (3rd hamonic and higher) with respect to STIGMA"""
Phi = self.Phi()
lognum = 1 + stigma**2 - 2 * stigma * np.sin(Phi)
return (
@@ -338,10 +335,7 @@ def d_delayS3p_H3_STIGMA_exact_d_STIGMA(self, H3, stigma, end_harm=None):
)
def d_delayS3p_H3_STIGMA_exact_d_Phi(self, H3, stigma, end_harm=None):
- """derivative of Shapiro delay third harmonics or higher harms
-
- exact format with respect to STIGMA
- """
+ """derivative of exact Shapiro delay (3rd hamonic and higher) with respect to phase"""
Phi = self.Phi()
lognum = 1 + stigma**2 - 2 * stigma * np.sin(Phi)
return (
@@ -352,7 +346,10 @@ def d_delayS3p_H3_STIGMA_exact_d_Phi(self, H3, stigma, end_harm=None):
)
def delayS_H3_STIGMA_exact(self, H3, stigma, end_harm=None):
- """P. Freire and N. Wex 2010 paper Eq (29)"""
+ """Shapiro delay (including all harmonics) using the exact form for very high inclinations.
+
+ Defined in Freire and Wex (2010), Eq (29).
+ """
Phi = self.Phi()
lognum = 1 + stigma**2 - 2 * stigma * np.sin(Phi)
return -2 * H3 / stigma**3 * np.log(lognum)
@@ -392,8 +389,7 @@ def d_delayS_d_par(self, par):
stigma = 0.0
else:
raise NotImplementedError(
- "ELL1H did not implemented %s parameter"
- " set yet." % str(self.fit_params)
+ f"ELL1H fit not implemented for {self.fit_params} parameters"
)
d_ds_func_name_base = f"d_{self.ds_func.__name__}_d_"
diff --git a/src/pint/models/stand_alone_psr_binaries/ELL1_model.py b/src/pint/models/stand_alone_psr_binaries/ELL1_model.py
index e90535f55..b3fd8133d 100644
--- a/src/pint/models/stand_alone_psr_binaries/ELL1_model.py
+++ b/src/pint/models/stand_alone_psr_binaries/ELL1_model.py
@@ -137,12 +137,196 @@ def d_Phi_d_par(self, par):
except Exception:
return self.d_M_d_par(par)
+ def delayI(self):
+ """Inverse time delay formula.
+
+ The treatment is similar to the one
+ in DD model (T. Damour & N. Deruelle (1986) equation [46-52])::
+
+ Dre = a1*(sin(Phi)+eps1/2*sin(2*Phi)+eps1/2*cos(2*Phi))
+ Drep = dDre/dt
+ Drepp = d^2 Dre/dt^2
+ nhat = dPhi/dt = 2pi/pb
+ nhatp = d^2Phi/dt^2 = 0
+ Dre(t-Dre(t-Dre(t))) = Dre(Phi) - Drep(Phi)*nhat*Dre(t-Dre(t))
+ = Dre(Phi) - Drep(Phi)*nhat*(Dre(Phi)-Drep(Phi)*nhat*Dre(t))
+ + 1/2 (Drepp(u)*nhat^2 + Drep(u) * nhat * nhatp) * (Dre(t)-...)^2
+ = Dre(Phi)(1 - nhat*Drep(Phi) + (nhat*Drep(Phi))^2
+ + 1/2*nhat^2* Dre*Drepp)
+ """
+ Dre = self.delayR()
+ Drep = self.Drep()
+ Drepp = self.Drepp()
+ PB = self.pb().to("second")
+ nhat = 2 * np.pi / self.pb()
+ return (
+ Dre
+ * (1 - nhat * Drep + (nhat * Drep) ** 2 + 1.0 / 2 * nhat**2 * Dre * Drepp)
+ ).decompose()
+
+ def nhat(self):
+ return 2 * np.pi / self.pb()
+
+ def d_nhat_d_par(self, par):
+ return -2 * np.pi / self.pb() ** 2 * self.d_pb_d_par(par)
+
+ def d_delayI_d_par(self, par):
+ """Delay derivative.
+
+ Computes::
+
+ delayI = Dre*(1 - nhat*Drep + (nhat*Drep)**2 + 1.0/2*nhat**2*Dre*Drepp)
+ d_delayI_d_par = d_delayI_d_Dre * d_Dre_d_par + d_delayI_d_Drep * d_Drep_d_par +
+ d_delayI_d_Drepp * d_Drepp_d_par + d_delayI_d_nhat * d_nhat_d_par
+ """
+ Dre = self.delayR()
+ Drep = self.Drep()
+ Drepp = self.Drepp()
+ PB = self.pb().to("second")
+ nhat = 2 * np.pi / self.pb()
+
+ d_delayI_d_Dre = (
+ 1 - nhat * Drep + (nhat * Drep) ** 2 + 1.0 / 2 * nhat**2 * Dre * Drepp
+ ) + Dre * 1.0 / 2 * nhat**2 * Drepp
+ d_delayI_d_Drep = -Dre * nhat + 2 * (nhat * Drep) * nhat * Dre
+ d_delayI_d_Drepp = 1.0 / 2 * (nhat * Dre) ** 2
+ d_delayI_d_nhat = Dre * (-Drep + 2 * (nhat * Drep) * Drep + nhat * Dre * Drepp)
+ d_nhat_d_par = self.prtl_der("nhat", par)
+ d_Dre_d_par = self.d_Dre_d_par(par)
+ d_Drep_d_par = self.d_Drep_d_par(par)
+ d_Drepp_d_par = self.d_Drepp_d_par(par)
+
+ return (
+ d_delayI_d_Dre * d_Dre_d_par
+ + d_delayI_d_Drep * d_Drep_d_par
+ + d_delayI_d_Drepp * d_Drepp_d_par
+ + d_delayI_d_nhat * d_nhat_d_par
+ )
+
+ def ELL1_om(self):
+ # arctan(om)
+ om = np.arctan2(self.eps1(), self.eps2())
+ return om.to(u.deg, equivalencies=u.dimensionless_angles())
+
+ def ELL1_ecc(self):
+ return np.sqrt(self.eps1() ** 2 + self.eps2() ** 2)
+
+ def ELL1_T0(self):
+ return self.TASC + self.pb() / (2 * np.pi) * (
+ np.arctan(self.eps1() / self.eps2())
+ ).to(u.Unit(""), equivalencies=u.dimensionless_angles())
+
+ ###############################
+ def d_delayR_da1(self):
+ """ELL1 Roemer delay in proper time divided by a1/c, including third order corrections
+
+ typo corrected from Zhu et al., following:
+ https://github.com/nanograv/tempo/blob/master/src/bnryell1.f
+ """
+ Phi = self.Phi()
+ eps1 = self.eps1()
+ eps2 = self.eps2()
+ return (
+ np.sin(Phi)
+ + 0.5 * (eps2 * np.sin(2 * Phi) - eps1 * np.cos(2 * Phi))
+ - (1.0 / 8)
+ * (
+ 5 * eps2**2 * np.sin(Phi)
+ - 3 * eps2**2 * np.sin(3 * Phi)
+ - 2 * eps2 * eps1 * np.cos(Phi)
+ + 6 * eps2 * eps1 * np.cos(3 * Phi)
+ + 3 * eps1**2 * np.sin(Phi)
+ + 3 * eps1**2 * np.sin(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ 5 * eps2**3 * np.sin(2 * Phi)
+ + 3 * eps1**2 * eps2 * np.sin(2 * Phi)
+ - 6 * eps1 * eps2**2 * np.cos(2 * Phi)
+ - 4 * eps1**3 * np.cos(2 * Phi)
+ - 4 * eps2**3 * np.sin(4 * Phi)
+ + 12 * eps1**2 * eps2 * np.sin(4 * Phi)
+ + 12 * eps1 * eps2**2 * np.cos(4 * Phi)
+ - 4 * eps1**3 * np.cos(4 * Phi)
+ )
+ )
+
+ def d_d_delayR_dPhi_da1(self):
+ """d (ELL1 Roemer delay)/dPhi in proper time divided by a1/c"""
+ Phi = self.Phi()
+ eps1 = self.eps1()
+ eps2 = self.eps2()
+ return (
+ np.cos(Phi)
+ + eps1 * np.sin(2 * Phi)
+ + eps2 * np.cos(2 * Phi)
+ - (1.0 / 8)
+ * (
+ 5 * eps2**2 * np.cos(Phi)
+ - 9 * eps2**2 * np.cos(3 * Phi)
+ + 2 * eps1 * eps2 * np.sin(Phi)
+ - 18 * eps1 * eps2 * np.sin(3 * Phi)
+ + 3 * eps1**2 * np.cos(Phi)
+ + 9 * eps1**2 * np.cos(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ 10 * eps2**3 * np.cos(2 * Phi)
+ + 6 * eps1**2 * eps2 * np.cos(2 * Phi)
+ + 12 * eps1 * eps2**2 * np.sin(2 * Phi)
+ + 8 * eps1**3 * np.sin(2 * Phi)
+ - 16 * eps2**3 * np.cos(4 * Phi)
+ + 48 * eps1**2 * eps2 * np.cos(4 * Phi)
+ - 48 * eps1 * eps2**2 * np.sin(4 * Phi)
+ + 16 * eps1**3 * np.sin(4 * Phi)
+ )
+ )
+
+ def d_dd_delayR_dPhi_da1(self):
+ """d^2 (ELL1 Roemer delay)/dPhi^2 in proper time divided by a1/c"""
+ Phi = self.Phi()
+ eps1 = self.eps1()
+ eps2 = self.eps2()
+ return (
+ -np.sin(Phi)
+ + 2 * eps1 * np.cos(2 * Phi)
+ - 2 * eps2 * np.sin(2 * Phi)
+ - (1.0 / 8)
+ * (
+ -5 * eps2**2 * np.sin(Phi)
+ + 27 * eps2**2 * np.sin(3 * Phi)
+ + 2 * eps1 * eps2 * np.cos(Phi)
+ - 54 * eps1 * eps2 * np.cos(3 * Phi)
+ - 3 * eps1**2 * np.sin(Phi)
+ - 27 * eps1**2 * np.sin(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ -20 * eps2**3 * np.sin(2 * Phi)
+ - 12 * eps1**2 * eps2 * np.sin(2 * Phi)
+ + 24 * eps1 * eps2**2 * np.cos(2 * Phi)
+ + 16 * eps1**3 * np.cos(2 * Phi)
+ + 64 * eps2**3 * np.sin(4 * Phi)
+ - 192 * eps1**2 * eps2 * np.sin(4 * Phi)
+ - 192 * eps1 * eps2**2 * np.cos(4 * Phi)
+ + 64 * eps1**3 * np.cos(4 * Phi)
+ )
+ )
+
+ def delayR(self):
+ """ELL1 Roemer delay in proper time.
+ Include terms up to third order in eccentricity
+ Zhu et al. (2019), Eqn. 1
+ Fiore et al. (2023), Eqn. 4
+ """
+ return ((self.a1() / c.c) * self.d_delayR_da1()).decompose()
+
def d_Dre_d_par(self, par):
"""Derivative computation.
Computes::
- Dre = delayR = a1/c.c*(sin(phi) - 0.5* eps1*cos(2*phi) + 0.5* eps2*sin(2*phi))
+ Dre = delayR = a1/c.c*(sin(phi) - 0.5* eps1*cos(2*phi) + 0.5* eps2*sin(2*phi) + ...)
d_Dre_d_par = d_a1_d_par /c.c*(sin(phi) - 0.5* eps1*cos(2*phi) + 0.5* eps2*sin(2*phi)) +
d_Dre_d_Phi * d_Phi_d_par + d_Dre_d_eps1*d_eps1_d_par + d_Dre_d_eps2*d_eps2_d_par
"""
@@ -153,18 +337,57 @@ def d_Dre_d_par(self, par):
d_a1_d_par = self.prtl_der("a1", par)
d_Dre_d_Phi = self.Drep()
d_Phi_d_par = self.prtl_der("Phi", par)
- d_Dre_d_eps1 = a1 / c.c * (-0.5 * np.cos(2 * Phi))
- d_Dre_d_eps2 = a1 / c.c * (0.5 * np.sin(2 * Phi))
+ d_Dre_d_eps1 = (
+ a1
+ / c.c
+ * (
+ -0.5 * np.cos(2 * Phi)
+ - (1.0 / 8)
+ * (
+ -2 * eps2 * np.cos(Phi)
+ + 6 * eps2 * np.cos(3 * Phi)
+ + 6 * eps1 * np.sin(Phi)
+ + 6 * eps1 * np.sin(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ 6 * eps1 * eps2 * np.sin(2 * Phi)
+ - 6 * eps2**2 * np.cos(2 * Phi)
+ - 12 * eps1**2 * np.cos(2 * Phi)
+ + 24 * eps1 * eps2 * np.sin(4 * Phi)
+ + 12 * eps2**2 * np.cos(4 * Phi)
+ - 12 * eps1**2 * np.cos(4 * Phi)
+ )
+ )
+ )
- with u.set_enabled_equivalencies(u.dimensionless_angles()):
- d_Dre_d_par = (
- d_a1_d_par
- / c.c
+ d_Dre_d_eps2 = (
+ a1
+ / c.c
+ * (
+ 0.5 * np.sin(2 * Phi)
+ - (1.0 / 8)
* (
- np.sin(Phi)
- - 0.5 * eps1 * np.cos(2 * Phi)
- + 0.5 * eps2 * np.sin(2 * Phi)
+ -2 * eps1 * np.cos(Phi)
+ + 6 * eps1 * np.cos(3 * Phi)
+ + 10 * eps2 * np.sin(Phi)
+ - 6 * eps2 * np.sin(3 * Phi)
)
+ - (1.0 / 12)
+ * (
+ 15 * eps2**2 * np.sin(2 * Phi)
+ + 3 * eps1**2 * np.sin(2 * Phi)
+ - 12 * eps1 * eps2 * np.cos(2 * Phi)
+ - 12 * eps2**2 * np.sin(4 * Phi)
+ + 12 * eps1**2 * np.sin(4 * Phi)
+ + 24 * eps1 * eps2 * np.cos(4 * Phi)
+ )
+ )
+ )
+
+ with u.set_enabled_equivalencies(u.dimensionless_angles()):
+ d_Dre_d_par = (
+ d_a1_d_par / c.c * self.d_delayR_da1()
+ d_Dre_d_Phi * d_Phi_d_par
+ d_Dre_d_eps1 * self.prtl_der("eps1", par)
+ d_Dre_d_eps2 * self.prtl_der("eps2", par)
@@ -174,27 +397,18 @@ def d_Dre_d_par(self, par):
def Drep(self):
"""dDre/dPhi"""
a1 = self.a1()
- eps1 = self.eps1()
- eps2 = self.eps2()
- Phi = self.Phi()
# Here we are using full d Dre/dPhi. But Tempo and Tempo2 ELL1 model
- # does not have the last two terms. This will result a difference in
+ # does not have terms beyond the first one. This will result a difference in
# the order of magnitude of 1e-8s level.
- return (
- a1
- / c.c
- * (np.cos(Phi) + eps1 * np.sin(2.0 * Phi) + eps2 * np.cos(2.0 * Phi))
- )
+ return a1 / c.c * self.d_d_delayR_dPhi_da1()
def d_Drep_d_par(self, par):
"""Derivative computation.
Computes::
- Drep = d_Dre_d_Phi = a1/c.c*(cos(Phi) + eps1 * sin(Phi) + eps2 * cos(Phi))
- d_Drep_d_par = d_a1_d_par /c.c*(cos(Phi) + eps1 * sin(Phi) + eps2 * cos(Phi)) +
- d_Drep_d_Phi * d_Phi_d_par + d_Drep_d_eps1*d_eps1_d_par +
- d_Drep_d_eps2*d_eps2_d_par
+ Drep = d_Dre_d_Phi = a1/c.c*(cos(Phi) + eps1 * sin(Phi) + eps2 * cos(Phi) + ...)
+ d_Drep_d_par = ...
"""
a1 = self.a1()
Phi = self.Phi()
@@ -203,14 +417,57 @@ def d_Drep_d_par(self, par):
d_a1_d_par = self.prtl_der("a1", par)
d_Drep_d_Phi = self.Drepp()
d_Phi_d_par = self.prtl_der("Phi", par)
- d_Drep_d_eps1 = a1 / c.c * np.sin(2.0 * Phi)
- d_Drep_d_eps2 = a1 / c.c * np.cos(2.0 * Phi)
+ d_Drep_d_eps1 = (
+ a1
+ / c.c
+ * (
+ np.sin(2.0 * Phi)
+ - (1.0 / 8)
+ * (
+ 6 * eps1 * np.cos(Phi)
+ + 18 * eps1 * np.cos(3 * Phi)
+ + 2 * eps2 * np.sin(Phi)
+ - 18 * eps2 * np.sin(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ 12 * eps1 * eps2 * np.cos(2 * Phi)
+ + 12 * eps2**2 * np.sin(2 * Phi)
+ + 16 * eps1**2 * np.sin(2 * Phi)
+ + 96 * eps1 * eps2 * np.cos(4 * Phi)
+ - 48 * eps2**2 * np.sin(4 * Phi)
+ + 48 * eps1**2 * np.sin(4 * Phi)
+ )
+ )
+ )
+
+ d_Drep_d_eps2 = (
+ a1
+ / c.c
+ * (
+ np.cos(2.0 * Phi)
+ - (1.0 / 8)
+ * (
+ 2 * eps1 * np.sin(Phi)
+ - 18 * eps1 * np.sin(3 * Phi)
+ + 10 * eps2 * np.cos(Phi)
+ - 18 * eps2 * np.cos(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ 30 * eps2**2 * np.cos(2 * Phi)
+ + 6 * eps1**2 * np.cos(2 * Phi)
+ + 24 * eps1 * eps2 * np.sin(2 * Phi)
+ - 48 * eps2**2 * np.cos(4 * Phi)
+ + 48 * eps1**2 * np.cos(4 * Phi)
+ - 96 * eps1 * eps2 * np.sin(4 * Phi)
+ )
+ )
+ )
with u.set_enabled_equivalencies(u.dimensionless_angles()):
d_Drep_d_par = (
- d_a1_d_par
- / c.c
- * (np.cos(Phi) + eps1 * np.sin(2.0 * Phi) + eps2 * np.cos(2.0 * Phi))
+ d_a1_d_par / c.c * self.d_d_delayR_dPhi_da1()
+ d_Drep_d_Phi * d_Phi_d_par
+ d_Drep_d_eps1 * self.prtl_der("eps1", par)
+ d_Drep_d_eps2 * self.prtl_der("eps2", par)
@@ -218,28 +475,17 @@ def d_Drep_d_par(self, par):
return d_Drep_d_par
def Drepp(self):
+ """d^2Dre/dPhi^2"""
a1 = self.a1()
- eps1 = self.eps1()
- eps2 = self.eps2()
- Phi = self.Phi()
- return (
- a1
- / c.c
- * (
- -np.sin(Phi)
- + 2.0 * (eps1 * np.cos(2.0 * Phi) - eps2 * np.sin(2.0 * Phi))
- )
- )
+ return a1 / c.c * self.d_dd_delayR_dPhi_da1()
def d_Drepp_d_par(self, par):
"""Derivative computation
Computes::
- Drepp = d_Drep_d_Phi = a1/c.c*(-sin(Phi) + 2.0* (eps1 * cos(2.0*Phi) - eps2 * sin(2.0*Phi)))
- d_Drepp_d_par = d_a1_d_par /c.c*(-sin(Phi) + 2.0* (eps1 * cos(2.0*Phi) - eps2 * sin(2.0*Phi))) +
- d_Drepp_d_Phi * d_Phi_d_par + d_Drepp_d_eps1*d_eps1_d_par +
- d_Drepp_d_eps2*d_eps2_d_par
+ Drepp = d_Drep_d_Phi = ...
+ d_Drepp_d_par = ...
"""
a1 = self.a1()
Phi = self.Phi()
@@ -252,116 +498,85 @@ def d_Drepp_d_par(self, par):
* (
-np.cos(Phi)
- 4.0 * (eps1 * np.sin(2.0 * Phi) + eps2 * np.cos(2.0 * Phi))
+ - (1.0 / 8)
+ * (
+ -5 * eps2**2 * np.cos(Phi)
+ + 81 * eps2**2 * np.cos(3 * Phi)
+ - 2 * eps1 * eps2 * np.sin(Phi)
+ + 162 * eps1 * eps2 * np.sin(3 * Phi)
+ - 3 * eps1**2 * np.cos(Phi)
+ - 81 * eps1**2 * np.cos(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ -40 * eps2**3 * np.cos(2 * Phi)
+ - 24 * eps1**2 * eps2 * np.cos(2 * Phi)
+ - 48 * eps1 * eps2**2 * np.sin(2 * Phi)
+ - 32 * eps1**3 * np.sin(2 * Phi)
+ + 256 * eps2**3 * np.cos(4 * Phi)
+ - 768 * eps1**2 * eps2 * np.cos(4 * Phi)
+ + 768 * eps1 * eps2**2 * np.sin(4 * Phi)
+ - 256 * eps1**3 * np.sin(4 * Phi)
+ )
)
)
- d_Phi_d_par = self.prtl_der("Phi", par)
- d_Drepp_d_eps1 = a1 / c.c * 2.0 * np.cos(2.0 * Phi)
- d_Drepp_d_eps2 = -a1 / c.c * 2.0 * np.sin(2.0 * Phi)
- with u.set_enabled_equivalencies(u.dimensionless_angles()):
- d_Drepp_d_par = (
- d_a1_d_par
- / c.c
+ d_Phi_d_par = self.prtl_der("Phi", par)
+ d_Drepp_d_eps1 = (
+ a1
+ / c.c
+ * (
+ 2.0 * np.cos(2.0 * Phi)
+ - (1.0 / 8)
* (
- -np.sin(Phi)
- + 2.0 * (eps1 * np.cos(2.0 * Phi) - eps2 * np.sin(2.0 * Phi))
+ -6 * eps1 * np.sin(Phi)
+ - 54 * eps1 * np.sin(3 * Phi)
+ + 2 * eps2 * np.cos(Phi)
+ - 54 * eps2 * np.cos(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ -24 * eps1 * eps2 * np.sin(2 * Phi)
+ + 24 * eps2**2 * np.cos(2 * Phi)
+ + 48 * eps1**2 * np.cos(2 * Phi)
+ - 384 * eps1 * eps2 * np.sin(4 * Phi)
+ - 192 * eps2**2 * np.cos(4 * Phi)
+ + 192 * eps1**2 * np.cos(4 * Phi)
)
- + d_Drepp_d_Phi * d_Phi_d_par
- + d_Drepp_d_eps1 * self.prtl_der("eps1", par)
- + d_Drepp_d_eps2 * self.prtl_der("eps2", par)
)
- return d_Drepp_d_par
-
- def delayR(self):
- """ELL1 Roemer delay in proper time. Ch. Lange et al 2001 eq. A6"""
- Phi = self.Phi()
- return (
- self.a1()
+ )
+ d_Drepp_d_eps2 = (
+ a1
/ c.c
* (
- np.sin(Phi)
- + 0.5 * (self.eps2() * np.sin(2 * Phi) - self.eps1() * np.cos(2 * Phi))
+ -2.0 * np.sin(2.0 * Phi)
+ - (1.0 / 8)
+ * (
+ 2 * eps1 * np.cos(Phi)
+ - 54 * eps1 * np.cos(3 * Phi)
+ - 10 * eps2 * np.sin(Phi)
+ + 54 * eps2 * np.sin(3 * Phi)
+ )
+ - (1.0 / 12)
+ * (
+ -60 * eps2**2 * np.sin(2 * Phi)
+ - 12 * eps1**2 * np.sin(2 * Phi)
+ + 48 * eps1 * eps2 * np.cos(2 * Phi)
+ + 192 * eps2**2 * np.sin(4 * Phi)
+ - 192 * eps1**2 * np.sin(4 * Phi)
+ - 384 * eps1 * eps2 * np.cos(4 * Phi)
+ )
)
- ).decompose()
-
- def delayI(self):
- """Inverse time delay formula.
-
- The treatment is similar to the one
- in DD model (T. Damour & N. Deruelle (1986) equation [46-52])::
-
- Dre = a1*(sin(Phi)+eps1/2*sin(2*Phi)+eps1/2*cos(2*Phi))
- Drep = dDre/dt
- Drepp = d^2 Dre/dt^2
- nhat = dPhi/dt = 2pi/pb
- nhatp = d^2Phi/dt^2 = 0
- Dre(t-Dre(t-Dre(t))) = Dre(Phi) - Drep(Phi)*nhat*Dre(t-Dre(t))
- = Dre(Phi) - Drep(Phi)*nhat*(Dre(Phi)-Drep(Phi)*nhat*Dre(t))
- + 1/2 (Drepp(u)*nhat^2 + Drep(u) * nhat * nhatp) * (Dre(t)-...)^2
- = Dre(Phi)(1 - nhat*Drep(Phi) + (nhat*Drep(Phi))^2
- + 1/2*nhat^2* Dre*Drepp)
- """
- Dre = self.delayR()
- Drep = self.Drep()
- Drepp = self.Drepp()
- PB = self.pb().to("second")
- nhat = 2 * np.pi / self.pb()
- return (
- Dre
- * (1 - nhat * Drep + (nhat * Drep) ** 2 + 1.0 / 2 * nhat**2 * Dre * Drepp)
- ).decompose()
-
- def nhat(self):
- return 2 * np.pi / self.pb()
-
- def d_nhat_d_par(self, par):
- return -2 * np.pi / self.pb() ** 2 * self.d_pb_d_par(par)
-
- def d_delayI_d_par(self, par):
- """Delay derivative.
-
- Computes::
-
- delayI = Dre*(1 - nhat*Drep + (nhat*Drep)**2 + 1.0/2*nhat**2*Dre*Drepp)
- d_delayI_d_par = d_delayI_d_Dre * d_Dre_d_par + d_delayI_d_Drep * d_Drep_d_par +
- d_delayI_d_Drepp * d_Drepp_d_par + d_delayI_d_nhat * d_nhat_d_par
- """
- Dre = self.delayR()
- Drep = self.Drep()
- Drepp = self.Drepp()
- PB = self.pb().to("second")
- nhat = 2 * np.pi / self.pb()
-
- d_delayI_d_Dre = (
- 1 - nhat * Drep + (nhat * Drep) ** 2 + 1.0 / 2 * nhat**2 * Dre * Drepp
- ) + Dre * 1.0 / 2 * nhat**2 * Drepp
- d_delayI_d_Drep = -Dre * nhat + 2 * (nhat * Drep) * nhat * Dre
- d_delayI_d_Drepp = 1.0 / 2 * (nhat * Dre) ** 2
- d_delayI_d_nhat = Dre * (-Drep + 2 * (nhat * Drep) * Drep + nhat * Dre * Drepp)
- d_nhat_d_par = self.prtl_der("nhat", par)
- d_Dre_d_par = self.d_Dre_d_par(par)
- d_Drep_d_par = self.d_Drep_d_par(par)
- d_Drepp_d_par = self.d_Drepp_d_par(par)
-
- return (
- d_delayI_d_Dre * d_Dre_d_par
- + d_delayI_d_Drep * d_Drep_d_par
- + d_delayI_d_Drepp * d_Drepp_d_par
- + d_delayI_d_nhat * d_nhat_d_par
)
- def ELL1_om(self):
- # arctan(om)
- om = np.arctan2(self.eps1(), self.eps2())
- return om.to(u.deg, equivalencies=u.dimensionless_angles())
-
- def ELL1_ecc(self):
- return np.sqrt(self.eps1() ** 2 + self.eps2() ** 2)
-
- def ELL1_T0(self):
- return self.TASC + self.pb() / (2 * np.pi) * (
- np.arctan(self.eps1() / self.eps2())
- ).to(u.Unit(""), equivalencies=u.dimensionless_angles())
+ with u.set_enabled_equivalencies(u.dimensionless_angles()):
+ d_Drepp_d_par = (
+ d_a1_d_par / c.c * self.d_dd_delayR_dPhi_da1()
+ + d_Drepp_d_Phi * d_Phi_d_par
+ + d_Drepp_d_eps1 * self.prtl_der("eps1", par)
+ + d_Drepp_d_eps2 * self.prtl_der("eps2", par)
+ )
+ return d_Drepp_d_par
class ELL1model(ELL1BaseModel):
diff --git a/src/pint/models/stand_alone_psr_binaries/ELL1k_model.py b/src/pint/models/stand_alone_psr_binaries/ELL1k_model.py
index 78138be9a..0fe785988 100644
--- a/src/pint/models/stand_alone_psr_binaries/ELL1k_model.py
+++ b/src/pint/models/stand_alone_psr_binaries/ELL1k_model.py
@@ -28,6 +28,7 @@ def __init__(self):
self.param_default_value.pop("EPS1DOT")
self.param_default_value.pop("EPS2DOT")
+ self.param_default_value.pop("EDOT")
self.param_default_value.update(
{"OMDOT": u.Quantity(0, "deg/year"), "LNEDOT": u.Quantity(0, "1/year")}
)
diff --git a/src/pint/models/stand_alone_psr_binaries/binary_generic.py b/src/pint/models/stand_alone_psr_binaries/binary_generic.py
index d19f5e1a7..c62f174a3 100644
--- a/src/pint/models/stand_alone_psr_binaries/binary_generic.py
+++ b/src/pint/models/stand_alone_psr_binaries/binary_generic.py
@@ -255,6 +255,7 @@ def d_binarydelay_d_par(self, par):
raise AttributeError(
f"Can not find parameter {par} in {self.binary_name} model"
)
+
# Get first derivative in the delay derivative function
result = self.d_binarydelay_d_par_funcs[0](par)
if len(self.d_binarydelay_d_par_funcs) > 1:
@@ -317,7 +318,6 @@ def prtl_der(self, y, x):
xU = U[1]
# Call derivative functions
derU = yU / xU
-
if hasattr(self, f"d_{y}_d_{x}"):
dername = f"d_{y}_d_{x}"
result = getattr(self, dername)()
@@ -355,7 +355,6 @@ def compute_eccentric_anomaly(self, eccentricity, mean_anomaly):
e = np.longdouble(eccentricity).value
else:
e = eccentricity
- print(f"e: {e}")
if any(e < 0) or any(e >= 1):
raise ValueError("Eccentricity should be in the range of [0,1).")
@@ -381,9 +380,9 @@ def get_tt0(self, barycentricTOA):
def ecc(self):
"""Calculate eccentricity with EDOT"""
- ECC = self.ECC
- EDOT = self.EDOT
- return ECC + (self.tt0 * EDOT).decompose()
+ if hasattr(self, "_tt0"):
+ return self.ECC + (self.tt0 * self.EDOT).decompose()
+ return self.ECC
def d_ecc_d_T0(self):
result = np.empty(len(self.tt0))
@@ -397,7 +396,7 @@ def d_ecc_d_EDOT(self):
return self.tt0
def a1(self):
- return self.A1 + self.tt0 * self.A1DOT
+ return self.A1 + self.tt0 * self.A1DOT if hasattr(self, "_tt0") else self.A1
def d_a1_d_A1(self):
return np.longdouble(np.ones(len(self.tt0))) * u.Unit("")
@@ -534,6 +533,7 @@ def d_E_d_par(self, par):
else:
E = self.E()
return np.zeros(len(self.tt0)) * E.unit / par_obj.unit
+ return func()
def nu(self):
"""True anomaly (Ae)"""
diff --git a/src/pint/models/stand_alone_psr_binaries/binary_orbits.py b/src/pint/models/stand_alone_psr_binaries/binary_orbits.py
index 96eea8de1..7ee4f6b81 100644
--- a/src/pint/models/stand_alone_psr_binaries/binary_orbits.py
+++ b/src/pint/models/stand_alone_psr_binaries/binary_orbits.py
@@ -26,8 +26,7 @@ def orbit_phase(self):
"""Orbital phase (between zero and two pi)."""
orbits = self.orbits()
norbits = np.array(np.floor(orbits), dtype=np.compat.long)
- phase = (orbits - norbits) * 2 * np.pi * u.rad
- return phase
+ return (orbits - norbits) * 2 * np.pi * u.rad
def pbprime(self):
"""Derivative of binary period with respect to time."""
@@ -47,7 +46,7 @@ def d_orbits_d_par(self, par):
"""
par_obj = getattr(self, par)
try:
- func = getattr(self, "d_orbits_d_" + par)
+ func = getattr(self, f"d_orbits_d_{par}")
except AttributeError:
def func():
@@ -60,7 +59,7 @@ def d_pbprime_d_par(self, par):
"""Derivative of binary period with respect to some parameter."""
par_obj = getattr(self, par)
try:
- func = getattr(self, "d_pbprime_d_" + par)
+ func = getattr(self, f"d_pbprime_d_{par}")
except AttributeError:
def func():
@@ -72,13 +71,12 @@ def func():
def __getattr__(self, name):
try:
return super().__getattribute__(name)
- except AttributeError:
+ except AttributeError as e:
p = super().__getattribute__("_parent")
if p is None:
raise AttributeError(
- "'%s' object has no attribute '%s'."
- % (self.__class__.__name__, name)
- )
+ f"'{self.__class__.__name__}' object has no attribute '{name}'."
+ ) from e
else:
return self._parent.__getattribute__(name)
@@ -102,10 +100,9 @@ def orbits(self):
PB = self.PB.to("second")
PBDOT = self.PBDOT
XPBDOT = self.XPBDOT
- orbits = (
+ return (
self.tt0 / PB - 0.5 * (PBDOT + XPBDOT) * (self.tt0 / PB) ** 2
).decompose()
- return orbits
def pbprime(self):
"""Derivative of binary period with respect to time."""
@@ -151,6 +148,8 @@ def d_pbprime_d_PBDOT(self):
return self.tt0
def d_pbprime_d_T0(self):
+ if not np.isscalar(self.PBDOT):
+ return -self.PBDOT
result = np.empty(len(self.tt0))
result.fill(-self.PBDOT.value)
return result * u.Unit(self.PBDOT.unit)
@@ -164,10 +163,9 @@ def __init__(self, parent, orbit_params=["FB0"]):
# add the rest of FBX parameters.
indices = set()
for k in self.binary_params:
- if re.match(r"FB\d+", k) is not None:
- if k not in self.orbit_params:
- self.orbit_params += [k]
- indices.add(int(k[2:]))
+ if re.match(r"FB\d+", k) is not None and k not in self.orbit_params:
+ self.orbit_params += [k]
+ indices.add(int(k[2:]))
if indices != set(range(len(indices))):
raise ValueError(
f"Indices must be 0 up to some number k without gaps "
@@ -177,8 +175,8 @@ def __init__(self, parent, orbit_params=["FB0"]):
def _FBXs(self):
FBXs = [0 * u.Unit("")]
ii = 0
- while "FB" + str(ii) in self.orbit_params:
- FBXs.append(getattr(self, "FB" + str(ii)))
+ while f"FB{ii}" in self.orbit_params:
+ FBXs.append(getattr(self, f"FB{ii}"))
ii += 1
return FBXs
@@ -198,21 +196,21 @@ def pbdot_orbit(self):
return -(self.pbprime() ** 2) * orbit_freq_dot
def d_orbits_d_par(self, par):
- if re.match(r"FB\d+", par) is not None:
- result = self.d_orbits_d_FBX(par)
- else:
- result = super().d_orbits_d_par(par)
- return result
+ return (
+ self.d_orbits_d_FBX(par)
+ if re.match(r"FB\d+", par) is not None
+ else super().d_orbits_d_par(par)
+ )
def d_orbits_d_FBX(self, FBX):
par = getattr(self, FBX)
ii = 0
FBXs = [0 * u.Unit("")]
- while "FB" + str(ii) in self.orbit_params:
- if "FB" + str(ii) != FBX:
- FBXs.append(0.0 * getattr(self, "FB" + str(ii)).unit)
+ while f"FB{ii}" in self.orbit_params:
+ if f"FB{ii}" != FBX:
+ FBXs.append(0.0 * getattr(self, f"FB{ii}").unit)
else:
- FBXs.append(1.0 * getattr(self, "FB" + str(ii)).unit)
+ FBXs.append(1.0 * getattr(self, f"FB{ii}").unit)
break
ii += 1
d_orbits = taylor_horner(self.tt0, FBXs) / par.unit
@@ -222,11 +220,11 @@ def d_pbprime_d_FBX(self, FBX):
par = getattr(self, FBX)
ii = 0
FBXs = [0 * u.Unit("")]
- while "FB" + str(ii) in self.orbit_params:
- if "FB" + str(ii) != FBX:
- FBXs.append(0.0 * getattr(self, "FB" + str(ii)).unit)
+ while f"FB{ii}" in self.orbit_params:
+ if f"FB{ii}" != FBX:
+ FBXs.append(0.0 * getattr(self, f"FB{ii}").unit)
else:
- FBXs.append(1.0 * getattr(self, "FB" + str(ii)).unit)
+ FBXs.append(1.0 * getattr(self, f"FB{ii}").unit)
break
ii += 1
d_FB = taylor_horner_deriv(self.tt0, FBXs, 1) / par.unit
@@ -234,8 +232,8 @@ def d_pbprime_d_FBX(self, FBX):
def d_pbprime_d_par(self, par):
par_obj = getattr(self, par)
- if re.match(r"FB\d+", par) is not None:
- result = self.d_pbprime_d_FBX(par)
- else:
- result = np.zeros(len(self.tt0)) * u.second / par_obj.unit
- return result
+ return (
+ self.d_pbprime_d_FBX(par)
+ if re.match(r"FB\d+", par) is not None
+ else np.zeros(len(self.tt0)) * u.second / par_obj.unit
+ )
diff --git a/src/pint/models/tcb_conversion.py b/src/pint/models/tcb_conversion.py
new file mode 100644
index 000000000..c4adccfed
--- /dev/null
+++ b/src/pint/models/tcb_conversion.py
@@ -0,0 +1,172 @@
+"""TCB to TDB conversion of a timing model."""
+
+import numpy as np
+
+from pint.models.parameter import MJDParameter
+from loguru import logger as log
+
+__all__ = [
+ "IFTE_K",
+ "scale_parameter",
+ "transform_mjd_parameter",
+ "convert_tcb_to_tdb",
+]
+
+# These constants are taken from Irwin & Fukushima 1999.
+# These are the same as the constants used in tempo2 as of 10 Feb 2023.
+IFTE_MJD0 = np.longdouble("43144.0003725")
+IFTE_KM1 = np.longdouble("1.55051979176e-8")
+IFTE_K = 1 + IFTE_KM1
+
+
+def scale_parameter(model, param, n, backwards):
+ """Scale a parameter x by a power of IFTE_K
+ x_tdb = x_tcb * IFTE_K**n
+
+ The power n depends on the "effective dimensionality" of
+ the parameter as it appears in the timing model. Some examples
+ are given bellow:
+
+ 1. F0 has effective dimensionality of frequency and n = 1
+ 2. F1 has effective dimensionality of frequency^2 and n = 2
+ 3. A1 has effective dimensionality of time because it appears as
+ A1/c in the timing model. Therefore, its n = -1
+ 4. DM has effective dimensionality of frequency because it appears
+ as DM*DMconst in the timing model. Therefore, its n = 1
+ 5. PBDOT is dimensionless and has n = 0. i.e., it is not scaled.
+
+ Parameter
+ ---------
+ model : pint.models.timing_model.TimingModel
+ The timing model
+ param : str
+ The parameter name to be converted
+ n : int
+ The power of IFTE_K in the scaling factor
+ backwards : bool
+ Whether to do TDB to TCB conversion.
+ """
+ assert isinstance(n, int), "The power must be an integer."
+
+ p = -1 if backwards else 1
+
+ factor = IFTE_K ** (p * n)
+
+ if hasattr(model, param) and getattr(model, param).quantity is not None:
+ par = getattr(model, param)
+ par.value *= factor
+ if par.uncertainty_value is not None:
+ par.uncertainty_value *= factor
+
+
+def transform_mjd_parameter(model, param, backwards):
+ """Convert an MJD from TCB to TDB or vice versa.
+ t_tdb = (t_tcb - IFTE_MJD0) / IFTE_K + IFTE_MJD0
+ t_tcb = (t_tdb - IFTE_MJD0) * IFTE_K + IFTE_MJD0
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+ The timing model
+ param : str
+ The parameter name to be converted
+ backwards : bool
+ Whether to do TDB to TCB conversion.
+ """
+ factor = IFTE_K if backwards else 1 / IFTE_K
+ tref = IFTE_MJD0
+
+ if hasattr(model, param) and getattr(model, param).quantity is not None:
+ par = getattr(model, param)
+ assert isinstance(par, MJDParameter)
+
+ par.value = (par.value - tref) * factor + tref
+ if par.uncertainty_value is not None:
+ par.uncertainty_value *= factor
+
+
+def convert_tcb_tdb(model, backwards=False):
+ """This function performs a partial conversion of a model
+ specified in TCB to TDB. While this should be sufficient as
+ a starting point, the resulting parameters are only approximate
+ and the model should be re-fit.
+
+ This is based on the `transform` plugin of tempo2.
+
+ The following parameters are converted to TDB:
+ 1. Spin frequency, its derivatives and spin epoch
+ 2. Sky coordinates, proper motion and the position epoch
+ 3. DM, DM derivatives and DM epoch
+ 4. Keplerian binary parameters and FB1
+
+ The following parameters are NOT converted although they are
+ in fact affected by the TCB to TDB conversion:
+ 1. Parallax
+ 2. TZRMJD and TZRFRQ
+ 2. DMX parameters
+ 3. Solar wind parameters
+ 4. Binary post-Keplerian parameters including Shapiro delay
+ parameters (except FB1)
+ 5. Jumps and DM Jumps
+ 6. FD parameters
+ 7. EQUADs
+ 8. Red noise parameters including FITWAVES, powerlaw red noise and
+ powerlaw DM noise parameters
+
+ Parameters
+ ----------
+ model : pint.models.timing_model.TimingModel
+ Timing model to be converted.
+ backwards : bool
+ Whether to do TDB to TCB conversion. The default is TCB to TDB.
+ """
+
+ target_units = "TCB" if backwards else "TDB"
+
+ if model.UNITS.value == target_units or (
+ model.UNITS.value is None and not backwards
+ ):
+ log.warning("The input par file is already in the target units. Doing nothing.")
+ return
+
+ log.warning(
+ "Converting this timing model from TCB to TDB. "
+ "Please note that the TCB to TDB conversion is only approximate and "
+ "the resulting timing model should be re-fit to get reliable results."
+ )
+
+ if "Spindown" in model.components:
+ for n, Fn_par in model.get_prefix_mapping("F").items():
+ scale_parameter(model, Fn_par, n + 1, backwards)
+
+ transform_mjd_parameter(model, "PEPOCH", backwards)
+
+ if "AstrometryEquatorial" in model.components:
+ scale_parameter(model, "PMRA", 1, backwards)
+ scale_parameter(model, "PMDEC", 1, backwards)
+ elif "AstrometryEcliptic" in model.components:
+ scale_parameter(model, "PMELAT", 1, backwards)
+ scale_parameter(model, "PMELONG", 1, backwards)
+ transform_mjd_parameter(model, "POSEPOCH", backwards)
+
+ # Although DM has the unit pc/cm^3, the quantity that enters
+ # the timing model is DMconst*DM, which has dimensions
+ # of frequency. Hence, DM and its derivatives will be
+ # scaled by IFTE_K**(i+1).
+ if "DispersionDM" in model.components:
+ scale_parameter(model, "DM", 1, backwards)
+ for n, DMn_par in model.get_prefix_mapping("DM").items():
+ scale_parameter(model, DMn_par, n + 1, backwards)
+ transform_mjd_parameter(model, "DMEPOCH", backwards)
+
+ if hasattr(model, "BINARY") and getattr(model, "BINARY").value is not None:
+ transform_mjd_parameter(model, "T0", backwards)
+ transform_mjd_parameter(model, "TASC", backwards)
+ scale_parameter(model, "PB", -1, backwards)
+ scale_parameter(model, "FB0", 1, backwards)
+ scale_parameter(model, "FB1", 2, backwards)
+ scale_parameter(model, "A1", -1, backwards)
+
+ model.UNITS.value = target_units
+
+ model.validate(allow_tcb=backwards)
diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py
index d371c5aa1..1bd05bfa2 100644
--- a/src/pint/models/timing_model.py
+++ b/src/pint/models/timing_model.py
@@ -26,6 +26,8 @@
See :ref:`Timing Models` for more details on how PINT's timing models work.
"""
+
+
import abc
import copy
import inspect
@@ -78,7 +80,7 @@
"MissingBinaryError",
"UnknownBinaryModel",
]
-# Parameters or lines in parfiles we don't understand but shouldn't
+# Parameters or lines in par files we don't understand but shouldn't
# complain about. These are still passed to components so that they
# can use them if they want to.
#
@@ -86,22 +88,19 @@
# errors in the par file.
#
# Comparisons with keywords in par file lines is done in a case insensitive way.
-ignore_params = set(
- [
- "TRES",
- "TZRMJD",
- "TZRFRQ",
- "TZRSITE",
- "NITS",
- "IBOOT",
- "CHI2R",
- "MODE",
- "PLANET_SHAPIRO2",
- # 'NE_SW', 'NE_SW2',
- ]
-)
-
-ignore_prefix = set(["DMXF1_", "DMXF2_", "DMXEP_"]) # DMXEP_ for now.
+ignore_params = {
+ "TRES",
+ "TZRMJD",
+ "TZRFRQ",
+ "TZRSITE",
+ "NITS",
+ "IBOOT",
+ "CHI2R",
+ "MODE",
+ "PLANET_SHAPIRO2",
+}
+
+ignore_prefix = {"DMXF1_", "DMXF2_", "DMXEP_"}
DEFAULT_ORDER = [
"astrometry",
@@ -184,12 +183,11 @@ class TimingModel:
removed with methods on this object, and for many of them additional
parameters in families (``DMXEP_1234``) can be added.
- Parameters in a TimingModel object are listed in the ``model.params`` and
- ``model.params_ordered`` objects. Each Parameter can be set as free or
- frozen using its ``.frozen`` attribute, and a list of the free parameters
- is available through the ``model.free_params`` property; this can also
- be used to set which parameters are free. Several methods are available
- to get and set some or all parameters in the forms of dictionaries.
+ Parameters in a TimingModel object are listed in the ``model.params`` object.
+ Each Parameter can be set as free or frozen using its ``.frozen`` attribute,
+ and a list of the free parameters is available through the ``model.free_params``
+ property; this can also be used to set which parameters are free. Several methods
+ are available to get and set some or all parameters in the forms of dictionaries.
TimingModel objects also support a number of functions for computing
various things like orbital phase, and barycentric versions of TOAs,
@@ -232,8 +230,8 @@ class TimingModel:
- Derivatives of delay and phase respect to parameter for fitting toas.
Each timing parameters are stored as TimingModel attribute in the type of
- :class:`~pint.models.parameter.Parameter` delay or phase and its derivatives are implemented
- as TimingModel Methods.
+ :class:`~pint.models.parameter.Parameter` delay or phase and its derivatives
+ are implemented as TimingModel Methods.
Attributes
----------
@@ -359,10 +357,11 @@ def __repr__(self):
def __str__(self):
return self.as_parfile()
- def validate(self):
+ def validate(self, allow_tcb=False):
"""Validate component setup.
- The checks include required parameters and parameter values.
+ The checks include required parameters and parameter values, and component types.
+ See also: :func:`pint.models.timing_model.TimingModel.validate_component_types`.
"""
if self.DILATEFREQ.value:
warn("PINT does not support 'DILATEFREQ Y'")
@@ -373,27 +372,109 @@ def validate(self):
if self.T2CMETHOD.value not in [None, "IAU2000B"]: # FIXME: really?
warn("PINT only supports 'T2CMETHOD IAU2000B'")
self.T2CMETHOD.value = "IAU2000B"
- if self.UNITS.value not in [None, "TDB"]:
- if self.UNITS.value == "TCB":
- error_message = """The TCB timescale is not supported by PINT. (PINT only supports 'UNITS TDB'.)
- See https://nanograv-pint.readthedocs.io/en/latest/explanation.html#time-scales for an explanation
- on different timescales. The par file can be converted from TCB to TDB using the `transform`
- plugin of TEMPO2 like so:
- $ tempo2 -gr transform J1234+6789_tcb.par J1234+6789_tdb.par tdb
+
+ if self.UNITS.value not in [None, "TDB", "TCB"]:
+ error_message = f"PINT only supports 'UNITS TDB'. The given timescale '{self.UNITS.value}' is invalid."
+ raise ValueError(error_message)
+ elif self.UNITS.value == "TCB":
+ if not allow_tcb:
+ error_message = """The TCB timescale is not fully supported by PINT.
+ PINT only supports 'UNITS TDB' internally. See https://nanograv-pint.readthedocs.io/en/latest/explanation.html#time-scales
+ for an explanation on different timescales. A TCB par file can be
+ converted to TDB using the `tcb2tdb` command like so:
+
+ $ tcb2tdb J1234+6789_tcb.par J1234+6789_tdb.par
+
+ However, this conversion is not exact and a fit must be performed to obtain
+ reliable results. Note that PINT only supports writing TDB par files.
"""
+ raise ValueError(error_message)
else:
- error_message = f"PINT only supports 'UNITS TDB'. The given timescale '{self.UNITS.value}' is invalid."
- raise ValueError(error_message)
+ log.warning(
+ "PINT does not support 'UNITS TCB' internally. Reading this par file nevertheless "
+ "because the `allow_tcb` option was given. This `TimingModel` object should not be "
+ "used for anything except converting to TDB."
+ )
if not self.START.frozen:
- warn("START cannot be unfrozen...setting START.frozen to True")
+ warn("START cannot be unfrozen... Setting START.frozen to True")
self.START.frozen = True
if not self.FINISH.frozen:
- warn("FINISH cannot be unfrozen...setting FINISH.frozen to True")
+ warn("FINISH cannot be unfrozen... Setting FINISH.frozen to True")
self.FINISH.frozen = True
for cp in self.components.values():
cp.validate()
+ self.validate_component_types()
+
+ def validate_component_types(self):
+ """Physically motivated validation of a timing model. This method checks the
+ compatibility of different model components when used together.
+
+ This function throws an error if multiple deterministic components that model
+ the same effect are used together (e.g. :class:`pint.models.astrometry.AstrometryEquatorial`
+ and :class:`pint.models.astrometry.AstrometryEcliptic`). It emits a warning if
+ a deterministic component and a stochastic component that model the same effect
+ are used together (e.g. :class:`pint.models.noise_model.PLDMNoise`
+ and :class:`pint.models.dispersion_model.DispersionDMX`). It also requires that
+ one and only one :class:`pint.models.spindown.SpindownBase` component is present
+ in a timing model.
+ """
+
+ def num_components_of_type(type):
+ return len(
+ list(filter(lambda c: isinstance(c, type), self.components.values()))
+ )
+
+ from pint.models.spindown import SpindownBase
+
+ assert (
+ num_components_of_type(SpindownBase) == 1
+ ), "Model must have one and only one spindown component (Spindown or another subclass of SpindownBase)."
+
+ from pint.models.astrometry import Astrometry
+
+ assert (
+ num_components_of_type(Astrometry) <= 1
+ ), "Model can have at most one Astrometry component."
+
+ from pint.models.solar_system_shapiro import SolarSystemShapiro
+
+ if num_components_of_type(SolarSystemShapiro) == 1:
+ assert (
+ num_components_of_type(Astrometry) == 1
+ ), "Model cannot have SolarSystemShapiro component without an Astrometry component."
+
+ from pint.models.pulsar_binary import PulsarBinary
+
+ has_binary_attr = hasattr(self, "BINARY") and self.BINARY.value
+ if has_binary_attr:
+ assert (
+ num_components_of_type(PulsarBinary) == 1
+ ), "BINARY attribute is set but no PulsarBinary component found."
+ assert (
+ num_components_of_type(PulsarBinary) <= 1
+ ), "Model can have at most one PulsarBinary component."
+
+ from pint.models.solar_wind_dispersion import SolarWindDispersionBase
+
+ assert (
+ num_components_of_type(SolarWindDispersionBase) <= 1
+ ), "Model can have at most one solar wind dispersion component."
+
+ from pint.models.dispersion_model import DispersionDMX
+ from pint.models.wave import Wave
+ from pint.models.noise_model import PLRedNoise, PLDMNoise
+
+ if num_components_of_type((DispersionDMX, PLDMNoise)) > 1:
+ log.warning(
+ "DispersionDMX and PLDMNoise are being used together. They are two ways of modelling the same effect."
+ )
+ if num_components_of_type((Wave, PLRedNoise)) > 1:
+ log.warning(
+ "Wave and PLRedNoise are being used together. They are two ways of modelling the same effect."
+ )
+
# def __str__(self):
# result = ""
# comps = self.components
@@ -418,54 +499,60 @@ def __getattr__(self, name):
)
@property_exists
- def params(self):
- """List of all parameter names in this model and all its components (order is arbitrary)."""
- # FIXME: any reason not to just use params_ordered here?
- p = self.top_level_params
- for cp in self.components.values():
- p = p + cp.params
- return p
+ def params_ordered(self):
+ """List of all parameter names in this model and all its components.
+ This is the same as `params`."""
+
+ # Historically, this was different from `params` because Python
+ # dictionaries were unordered until Python 3.7. Now there is no reason for
+ # them to be different.
+
+ warn(
+ "`TimingModel.params_ordered` is now deprecated and may be removed in the future. "
+ "Use `TimingModel.params` instead. It gives the same output as `TimingModel.params_ordered`.",
+ DeprecationWarning,
+ )
+
+ return self.params
@property_exists
- def params_ordered(self):
+ def params(self):
"""List of all parameter names in this model and all its components, in a sensible order."""
+
# Define the order of components in the list
# Any not included will be printed between the first and last set.
# FIXME: make order completely canonical (sort components by name?)
+
start_order = ["astrometry", "spindown", "dispersion"]
last_order = ["jump_delay"]
compdict = self.get_components_by_category()
used_cats = set()
pstart = copy.copy(self.top_level_params)
for cat in start_order:
- if cat in compdict:
- cp = compdict[cat]
- for cpp in cp:
- pstart += cpp.params
- used_cats.add(cat)
- else:
+ if cat not in compdict:
continue
-
+ cp = compdict[cat]
+ for cpp in cp:
+ pstart += cpp.params
+ used_cats.add(cat)
pend = []
for cat in last_order:
- if cat in compdict:
- cp = compdict[cat]
- for cpp in cp:
- pend += cpp.parms
- used_cats.add(cat)
- else:
+ if cat not in compdict:
continue
+ cp = compdict[cat]
+ for cpp in cp:
+ pend += cpp.parms
+ used_cats.add(cat)
# Now collect any components that haven't already been included in the list
pmid = []
for cat in compdict:
if cat in used_cats:
continue
- else:
- cp = compdict[cat]
- for cpp in cp:
- pmid += cpp.params
- used_cats.add(cat)
+ cp = compdict[cat]
+ for cpp in cp:
+ pmid += cpp.params
+ used_cats.add(cat)
return pstart + pmid + pend
@@ -473,7 +560,7 @@ def params_ordered(self):
def free_params(self):
"""List of all the free parameters in the timing model. Can be set to change which are free.
- These are ordered as ``self.params_ordered`` does.
+ These are ordered as ``self.params`` does.
Upon setting, order does not matter, and aliases are accepted.
ValueError is raised if a parameter is not recognized.
@@ -481,7 +568,7 @@ def free_params(self):
On setting, parameter aliases are converted with
:func:`pint.models.timing_model.TimingModel.match_param_aliases`.
"""
- return [p for p in self.params_ordered if not getattr(self, p).frozen]
+ return [p for p in self.params if not getattr(self, p).frozen]
@free_params.setter
def free_params(self, params):
@@ -542,7 +629,7 @@ def get_params_dict(self, which="free", kind="quantity"):
if which == "free":
ps = self.free_params
elif which == "all":
- ps = self.params_ordered
+ ps = self.params
else:
raise ValueError("get_params_dict expects which to be 'all' or 'free'")
c = OrderedDict()
@@ -569,7 +656,7 @@ def get_params_of_component_type(self, component_type):
-------
list
"""
- component_type_list_str = "{}_list".format(component_type)
+ component_type_list_str = f"{component_type}_list"
if hasattr(self, component_type_list_str):
component_type_list = getattr(self, component_type_list_str)
return [
@@ -600,17 +687,14 @@ def set_param_uncertainties(self, fitp):
"""Set the model parameters to the value contained in the input dict."""
for k, v in fitp.items():
p = getattr(self, k)
- if isinstance(v, u.Quantity):
- p.uncertainty = v
- else:
- p.uncertainty = v * p.units
+ p.uncertainty = v if isinstance(v, u.Quantity) else v * p.units
@property_exists
def components(self):
"""All the components in a dictionary indexed by name."""
comps = {}
for ct in self.component_types:
- for cp in getattr(self, ct + "_list"):
+ for cp in getattr(self, f"{ct}_list"):
comps[cp.__class__.__name__] = cp
return comps
@@ -698,13 +782,11 @@ def orbital_phase(self, barytimes, anom="mean", radians=True):
elif anom.lower() == "true":
anoms = bbi.nu() # can be negative
else:
- raise ValueError("anom='%s' is not a recognized type of anomaly" % anom)
+ raise ValueError(f"anom='{anom}' is not a recognized type of anomaly")
# Make sure all angles are between 0-2*pi
anoms = np.remainder(anoms.value, 2 * np.pi)
- if radians: # return with radian units
- return anoms * u.rad
- else: # return as unitless cycles from 0-1
- return anoms / (2 * np.pi)
+ # return with radian units or return as unitless cycles from 0-1
+ return anoms * u.rad if radians else anoms / (2 * np.pi)
def conjunction(self, baryMJD):
"""Return the time(s) of the first superior conjunction(s) after baryMJD.
@@ -752,14 +834,7 @@ def funct(t):
scs = []
for bt in bts:
# Make 11 times over one orbit after bt
- if self.PB.value is not None:
- pb = self.PB.value
- elif self.FB0.quantity is not None:
- pb = (1 / self.FB0.quantity).to("day").value
- else:
- raise AttributeError(
- "Neither PB nor FB0 is present in the timing model."
- )
+ pb = self.pb()[0].to_value("day")
ts = np.linspace(bt, bt + pb, 11)
# Compute the true anomalies and omegas for those times
nus = self.orbital_phase(ts, anom="true")
@@ -867,13 +942,13 @@ def d_phase_d_delay_funcs(self):
def get_deriv_funcs(self, component_type, derivative_type=""):
"""Return dictionary of derivative functions."""
- # TODO, this function can be a more generical function collector.
+ # TODO, this function can be a more generic function collector.
deriv_funcs = defaultdict(list)
- if not derivative_type == "":
+ if derivative_type != "":
derivative_type += "_"
- for cp in getattr(self, component_type + "_list"):
+ for cp in getattr(self, f"{component_type}_list"):
try:
- df = getattr(cp, derivative_type + "deriv_funcs")
+ df = getattr(cp, f"{derivative_type}deriv_funcs")
except AttributeError:
continue
for k, v in df.items():
@@ -916,13 +991,11 @@ def get_component_type(self, component):
comp_base = inspect.getmro(component.__class__)
if comp_base[-2].__name__ != "Component":
raise TypeError(
- "Class '%s' is not a Component type class."
- % component.__class__.__name__
+ f"Class '{component.__class__.__name__}' is not a Component type class."
)
elif len(comp_base) < 3:
raise TypeError(
- "'%s' class is not a subclass of 'Component' class."
- % component.__class__.__name__
+ f"'{component.__class__.__name__}' class is not a subclass of 'Component' class."
)
else:
comp_type = comp_base[-3].__name__
@@ -950,17 +1023,16 @@ def map_component(self, component):
comps = self.components
if isinstance(component, str):
if component not in list(comps.keys()):
- raise AttributeError("No '%s' in the timing model." % component)
+ raise AttributeError(f"No '{component}' in the timing model.")
comp = comps[component]
- else: # When component is an component instance.
- if component not in list(comps.values()):
- raise AttributeError(
- "No '%s' in the timing model." % component.__class__.__name__
- )
- else:
- comp = component
+ elif component in list(comps.values()):
+ comp = component
+ else:
+ raise AttributeError(
+ f"No '{component.__class__.__name__}' in the timing model."
+ )
comp_type = self.get_component_type(comp)
- host_list = getattr(self, comp_type + "_list")
+ host_list = getattr(self, f"{comp_type}_list")
order = host_list.index(comp)
return comp, order, host_list, comp_type
@@ -979,9 +1051,9 @@ def add_component(
If true, add a duplicate component. Default is False.
"""
comp_type = self.get_component_type(component)
+ cur_cps = []
if comp_type in self.component_types:
- comp_list = getattr(self, comp_type + "_list")
- cur_cps = []
+ comp_list = getattr(self, f"{comp_type}_list")
for cp in comp_list:
# If component order is not defined.
cp_order = (
@@ -991,32 +1063,28 @@ def add_component(
# Check if the component has been added already.
if component.__class__ in (x.__class__ for x in comp_list):
log.warning(
- "Component '%s' is already present but was added again."
- % component.__class__.__name__
+ f"Component '{component.__class__.__name__}' is already present but was added again."
)
if not force:
raise ValueError(
- "Component '%s' is already present and will not be "
- "added again. To force add it, use force=True option."
- % component.__class__.__name__
+ f"Component '{component.__class__.__name__}' is already present and will not be "
+ f"added again. To force add it, use force=True option."
)
else:
self.component_types.append(comp_type)
- cur_cps = []
-
# link new component to TimingModel
component._parent = self
- # If the categore is not in the order list, it will be added to the end.
+ # If the category is not in the order list, it will be added to the end.
if component.category not in order:
- new_cp = tuple((len(order) + 1, component))
+ new_cp = len(order) + 1, component
else:
- new_cp = tuple((order.index(component.category), component))
+ new_cp = order.index(component.category), component
# add new component
cur_cps.append(new_cp)
cur_cps.sort(key=lambda x: x[0])
new_comp_list = [c[1] for c in cur_cps]
- setattr(self, comp_type + "_list", new_comp_list)
+ setattr(self, f"{comp_type}_list", new_comp_list)
# Set up components
if setup:
self.setup()
@@ -1048,8 +1116,8 @@ def _locate_param_host(self, param):
list of tuples
All possible components that host the target parameter. The first
element is the component object that have the target parameter, the
- second one is the parameter object. If it is a prefix-style parameter
- , it will return one example of such parameter.
+ second one is the parameter object. If it is a prefix-style parameter,
+ it will return one example of such parameter.
"""
result_comp = []
for cp_name, cp in self.components.items():
@@ -1081,8 +1149,8 @@ def add_param_from_top(self, param, target_component, setup=False):
Parameters
----------
- param: str
- Parameter name
+ param: pint.models.parameter.Parameter
+ Parameter instance
target_component: str
Parameter host component name. If given as "" it would add
parameter to the top level `TimingModel` class
@@ -1092,12 +1160,12 @@ def add_param_from_top(self, param, target_component, setup=False):
if target_component == "":
setattr(self, param.name, param)
self.top_level_params += [param.name]
- else:
- if target_component not in list(self.components.keys()):
- raise AttributeError(
- "Can not find component '%s' in " "timing model." % target_component
- )
+ elif target_component in list(self.components.keys()):
self.components[target_component].add_param(param, setup=setup)
+ else:
+ raise AttributeError(
+ f"Can not find component '{target_component}' in " "timing model."
+ )
def remove_param(self, param):
"""Remove a parameter from timing model.
@@ -1109,7 +1177,7 @@ def remove_param(self, param):
"""
param_map = self.get_params_mapping()
if param not in param_map:
- raise AttributeError("Can not find '%s' in timing model." % param)
+ raise AttributeError(f"Can not find '{param}' in timing model.")
if param_map[param] == "timing_model":
delattr(self, param)
self.top_level_params.remove(param)
@@ -1119,10 +1187,8 @@ def remove_param(self, param):
self.setup()
def get_params_mapping(self):
- """Report whick component each parameter name comes from."""
- param_mapping = {}
- for p in self.top_level_params:
- param_mapping[p] = "timing_model"
+ """Report which component each parameter name comes from."""
+ param_mapping = {p: "timing_model" for p in self.top_level_params}
for cp in list(self.components.values()):
for pp in cp.params:
param_mapping[pp] = cp.__class__.__name__
@@ -1145,7 +1211,7 @@ def get_prefix_mapping(self, prefix):
Returns
-------
dict
- A dictionary with prefix pararameter real index as key and parameter
+ A dictionary with prefix parameter real index as key and parameter
name as value.
"""
for cp in self.components.values():
@@ -1231,13 +1297,12 @@ def delay(self, toas, cutoff_component="", include_last=True):
idx = len(self.DelayComponent_list)
else:
delay_names = [x.__class__.__name__ for x in self.DelayComponent_list]
- if cutoff_component in delay_names:
- idx = delay_names.index(cutoff_component)
- if include_last:
- idx += 1
- else:
- raise KeyError("No delay component named '%s'." % cutoff_component)
+ if cutoff_component not in delay_names:
+ raise KeyError(f"No delay component named '{cutoff_component}'.")
+ idx = delay_names.index(cutoff_component)
+ if include_last:
+ idx += 1
# Do NOT cycle through delay_funcs - cycle through components until cutoff
for dc in self.DelayComponent_list[:idx]:
for df in dc.delay_funcs_component:
@@ -1260,31 +1325,35 @@ def phase(self, toas, abs_phase=None):
# abs_phase defaults to True if AbsPhase is in the model, otherwise to
# False. Of course, if you manually set it, it will use that setting.
if abs_phase is None:
- if "AbsPhase" in list(self.components.keys()):
- abs_phase = True
- else:
- abs_phase = False
- # If the absolute phase flag is on, use the TZR parameters to compute
- # the absolute phase.
- if abs_phase:
- if "AbsPhase" not in list(self.components.keys()):
- # if no absolute phase (TZRMJD), add the component to the model and calculate it
- from pint.models import absolute_phase
-
- self.add_component(absolute_phase.AbsPhase(), validate=False)
- self.make_TZR_toa(
- toas
- ) # TODO:needs timfile to get all toas, but model doesn't have access to timfile. different place for this?
- self.validate()
- tz_toa = self.get_TZR_toa(toas)
- tz_delay = self.delay(tz_toa)
- tz_phase = Phase(np.zeros(len(toas.table)), np.zeros(len(toas.table)))
- for pf in self.phase_funcs:
- tz_phase += Phase(pf(tz_toa, tz_delay))
- return phase - tz_phase
- else:
+ abs_phase = "AbsPhase" in list(self.components.keys())
+
+ # This function gets called in `Residuals.calc_phase_resids()` with `abs_phase=True`
+ # by default. Hence, this branch is not run by default.
+ if not abs_phase:
return phase
+ if "AbsPhase" not in list(self.components.keys()):
+ log.info("Creating a TZR TOA (AbsPhase) using the given TOAs object.")
+
+ # if no absolute phase (TZRMJD), add the component to the model and calculate it
+ self.add_tzr_toa(toas)
+
+ tz_toa = self.get_TZR_toa(toas)
+ tz_delay = self.delay(tz_toa)
+ tz_phase = Phase(np.zeros(len(toas.table)), np.zeros(len(toas.table)))
+ for pf in self.phase_funcs:
+ tz_phase += Phase(pf(tz_toa, tz_delay))
+ return phase - tz_phase
+
+ def add_tzr_toa(self, toas):
+ """Create a TZR TOA for the given TOAs object and add it to
+ the timing model. This corresponds to TOA closest to the PEPOCH."""
+ from pint.models.absolute_phase import AbsPhase
+
+ self.add_component(AbsPhase(), validate=False)
+ self.make_TZR_toa(toas)
+ self.validate()
+
def total_dm(self, toas):
"""Calculate dispersion measure from all the dispersion type of components."""
# Here we assume the unit would be the same for all the dm value function.
@@ -1398,8 +1467,8 @@ def noise_model_dimensions(self, toas):
"""Number of basis functions for each noise model component.
Returns a dictionary of correlated-noise components in the noise
- model. Each entry contains a tuple (offset, size) where size is the
- number of basis funtions for the component, and offset is their
+ model. Each entry contains a tuple (offset, size) where size is the
+ number of basis functions for the component, and offset is their
starting location in the design matrix and weights vector.
"""
result = {}
@@ -1463,13 +1532,9 @@ def jump_flags_to_params(self, toas):
if tjv in tim_jump_values:
log.info(f"JUMP -tim_jump {tjv} already exists")
tim_jump_values.remove(tjv)
- if used_indices:
- num = max(used_indices) + 1
- else:
- num = 1
-
+ num = max(used_indices) + 1 if used_indices else 1
if not tim_jump_values:
- log.info(f"All tim_jump values have corresponding JUMPs")
+ log.info("All tim_jump values have corresponding JUMPs")
return
# FIXME: arrange for these to be in a sensible order (might not be integers
@@ -1515,7 +1580,7 @@ def delete_jump_and_flags(self, toa_table, jump_num):
Specifies the index of the jump to be deleted.
"""
# remove jump of specified index
- self.remove_param("JUMP" + str(jump_num))
+ self.remove_param(f"JUMP{jump_num}")
# remove jump flags from selected TOA tables
if toa_table is not None:
@@ -1580,7 +1645,7 @@ def d_phase_d_toa(self, toas, sample_step=None):
if sample_step is None:
pulse_period = 1.0 / (self.F0.quantity)
sample_step = pulse_period * 2
- # Note that sample_dt is applied cumulatively, so this evaulates phase at TOA-dt and TOA+dt
+ # Note that sample_dt is applied cumulatively, so this evaluates phase at TOA-dt and TOA+dt
sample_dt = [-sample_step, 2 * sample_step]
sample_phase = []
@@ -1654,7 +1719,6 @@ def d_phase_d_param(self, toas, delay, param):
# d_Phase2/d_delay*d_delay/d_param
# = (d_Phase1/d_delay + d_Phase2/d_delay) *
# d_delay_d_param
-
d_delay_d_p = self.d_delay_d_param(toas, param)
dpdd_result = np.longdouble(np.zeros(toas.ntoas)) / u.second
for dpddf in self.d_phase_d_delay_funcs:
@@ -1669,8 +1733,8 @@ def d_delay_d_param(self, toas, param, acc_delay=None):
delay_derivs = self.delay_deriv_funcs
if param not in list(delay_derivs.keys()):
raise AttributeError(
- "Derivative function for '%s' is not provided"
- " or not registered. " % param
+ "Derivative function for '{param}' is not provided"
+ " or not registered. "
)
for df in delay_derivs[param]:
result += df(toas, param, acc_delay).to(
@@ -1687,10 +1751,7 @@ def d_phase_d_param_num(self, toas, param, step=1e-2):
par = getattr(self, param)
ori_value = par.value
unit = par.units
- if ori_value == 0:
- h = 1.0 * step
- else:
- h = ori_value * step
+ h = 1.0 * step if ori_value == 0 else ori_value * step
parv = [par.value - h, par.value + h]
phase_i = (
@@ -1721,13 +1782,10 @@ def d_delay_d_param_num(self, toas, param, step=1e-2):
ori_value = par.value
if ori_value is None:
# A parameter did not get to use in the model
- log.warning("Parameter '%s' is not used by timing model." % param)
+ log.warning(f"Parameter '{param}' is not used by timing model.")
return np.zeros(toas.ntoas) * (u.second / par.units)
unit = par.units
- if ori_value == 0:
- h = 1.0 * step
- else:
- h = ori_value * step
+ h = 1.0 * step if ori_value == 0 else ori_value * step
parv = [par.value - h, par.value + h]
delay = np.zeros((toas.ntoas, 2))
for ii, val in enumerate(parv):
@@ -1747,7 +1805,7 @@ def d_dm_d_param(self, data, param):
result = np.zeros(len(data)) << (u.pc / u.cm**3 / par.units)
dm_df = self.dm_derivs.get(param, None)
if dm_df is None:
- if param not in self.params: # Maybe add differentitable params
+ if param not in self.params: # Maybe add differentiable params
raise AttributeError(f"Parameter {param} does not exist")
else:
return result
@@ -1780,6 +1838,7 @@ def designmatrix(self, toas, acc_delay=None, incfrozen=False, incoffset=True):
Whether to include frozen parameters in the design matrix
incoffset : bool
Whether to include the constant offset in the design matrix
+ This option is ignored if a `PhaseOffset` component is present.
Returns
-------
@@ -1807,6 +1866,8 @@ def designmatrix(self, toas, acc_delay=None, incfrozen=False, incoffset=True):
# The entries for any unfrozen noise parameters will not be
# included in the design matrix as they are not well-defined.
+ incoffset = incoffset and "PhaseOffset" not in self.components
+
params = ["Offset"] if incoffset else []
params += [
par
@@ -1961,10 +2022,7 @@ def compare(
log.debug("Check verbosity - only warnings/info will be displayed")
othermodel = copy.deepcopy(othermodel)
- if (
- "POSEPOCH" in self.params_ordered
- and "POSEPOCH" in othermodel.params_ordered
- ):
+ if "POSEPOCH" in self.params and "POSEPOCH" in othermodel.params:
if (
self.POSEPOCH.value is not None
and othermodel.POSEPOCH.value is not None
@@ -1975,7 +2033,7 @@ def compare(
% (other_model_name, model_name)
)
othermodel.change_posepoch(self.POSEPOCH.value)
- if "PEPOCH" in self.params_ordered and "PEPOCH" in othermodel.params_ordered:
+ if "PEPOCH" in self.params and "PEPOCH" in othermodel.params:
if (
self.PEPOCH.value is not None
and self.PEPOCH.value != othermodel.PEPOCH.value
@@ -1984,7 +2042,7 @@ def compare(
"Updating PEPOCH in %s to match %s" % (other_model_name, model_name)
)
othermodel.change_pepoch(self.PEPOCH.value)
- if "DMEPOCH" in self.params_ordered and "DMEPOCH" in othermodel.params_ordered:
+ if "DMEPOCH" in self.params and "DMEPOCH" in othermodel.params:
if (
self.DMEPOCH.value is not None
and self.DMEPOCH.value != othermodel.DMEPOCH.value
@@ -2019,7 +2077,7 @@ def compare(
f"{model_name} is in ECL({self.ECL.value}) coordinates but {other_model_name} is in ICRS coordinates and convertcoordinates=False"
)
- for pn in self.params_ordered:
+ for pn in self.params:
par = getattr(self, pn)
if par.value is None:
continue
@@ -2246,8 +2304,8 @@ def compare(
)
# Now print any parameters in othermodel that were missing in self.
- mypn = self.params_ordered
- for opn in othermodel.params_ordered:
+ mypn = self.params
+ for opn in othermodel.params:
if opn in mypn and getattr(self, opn).value is not None:
continue
if nodmx and opn.startswith("DMX"):
@@ -3237,6 +3295,26 @@ def _param_alias_map(self):
alias[als] = tp
return alias
+ @lazyproperty
+ def _param_unit_map(self):
+ """A dictionary to map parameter names to their units
+
+ This excludes prefix parameters and aliases. Use :func:`param_to_unit` to handle those.
+ """
+ units = {}
+ for k, cp in self.components.items():
+ for p in cp.params:
+ if p in units.keys():
+ if units[p] != getattr(cp, p).units:
+ raise TimingModelError(
+ f"Units of parameter '{p}' in component '{cp}' ({getattr(cp, p).units}) do not match those of existing parameter ({units[p]})"
+ )
+ units[p] = getattr(cp, p).units
+ tm = TimingModel()
+ for tp in tm.params:
+ units[p] = getattr(tm, tp).units
+ return units
+
@lazyproperty
def repeatable_param(self):
"""Return the repeatable parameter map."""
@@ -3340,8 +3418,8 @@ def alias_to_pint_param(self, alias):
"""Translate a alias to a PINT parameter name.
This is a wrapper function over the property ``_param_alias_map``. It
- also handles the indexed parameters (e.g., `pint.models.parameter.prefixParameter`
- and `pint.models.parameter.maskParameter`) with and index beyond currently
+ also handles indexed parameters (e.g., `pint.models.parameter.prefixParameter`
+ and `pint.models.parameter.maskParameter`) with an index beyond those currently
initialized.
Parameters
@@ -3425,6 +3503,35 @@ def alias_to_pint_param(self, alias):
)
return pint_par, first_init_par
+ def param_to_unit(self, name):
+ """Return the unit associated with a parameter
+
+ This is a wrapper function over the property ``_param_unit_map``. It
+ also handles aliases and indexed parameters (e.g., `pint.models.parameter.prefixParameter`
+ and `pint.models.parameter.maskParameter`) with an index beyond those currently
+ initialized.
+
+ This can be used without an existing :class:`~pint.models.TimingModel`.
+
+ Parameters
+ ----------
+ name : str
+ Name of PINT parameter or alias
+
+ Returns
+ -------
+ astropy.u.Unit
+ """
+ pintname, firstname = self.alias_to_pint_param(name)
+ if pintname == firstname:
+ # not a prefix parameter
+ return self._param_unit_map[pintname]
+ prefix, idx_str, idx = split_prefixed_name(pintname)
+ component = self.param_component_map[firstname][0]
+ if getattr(self.components[component], firstname).unit_template is None:
+ return self._param_unit_map[firstname]
+ return u.Unit(getattr(self.components[component], firstname).unit_template(idx))
+
class TimingModelError(ValueError):
"""Generic base class for timing model errors."""
diff --git a/src/pint/models/troposphere_delay.py b/src/pint/models/troposphere_delay.py
index c6f59194a..911f2ad12 100644
--- a/src/pint/models/troposphere_delay.py
+++ b/src/pint/models/troposphere_delay.py
@@ -164,8 +164,7 @@ def troposphere_delay(self, toas, acc_delay=None):
# exclude non topocentric observations
if not isinstance(obsobj, TopoObs):
log.debug(
- "Skipping Troposphere delay for non Topocentric TOA: %s"
- % obsobj.name
+ f"Skipping Troposphere delay for non Topocentric TOA: {obsobj.name}"
)
continue
@@ -200,13 +199,13 @@ def _validate_altitudes(self, alt, obs=""):
isValid = np.logical_and(isPositive, isLessThan90)
# now make corrections to alt based on the valid status
- # if not valid, make them appear at the zenith to make the math sensical
+ # if not valid, make them appear at the zenith to make the math sensible
if not np.all(isValid):
# it's probably helpful to count how many are invalid
numInvalid = len(isValid) - np.count_nonzero(isValid)
message = "Invalid altitude calculated for %i TOAS" % numInvalid
if obs:
- message += " from observatory " + obs
+ message += f" from observatory {obs}"
log.warning(message)
# now correct the values
@@ -238,8 +237,7 @@ def pressure_from_altitude(self, H):
if gph > 11 * u.km:
log.warning("Pressure approximation invalid for elevations above 11 km")
T = 288.15 - 0.0065 * H.to(u.m).value # temperature lapse
- P = 101.325 * (288.15 / T) ** -5.25575 * u.kPa
- return P
+ return 101.325 * (288.15 / T) ** -5.25575 * u.kPa
def zenith_delay(self, lat, H):
"""Calculate the hydrostatic zenith delay"""
@@ -268,7 +266,7 @@ def _find_latitude_index(self, lat):
if absLat <= self.LAT[lInd]:
return lInd - 1
# else this is an invalid latitude... huh?
- raise ValueError("Invaid latitude: %s must be between -90 and 90 degrees" % lat)
+ raise ValueError(f"Invaid latitude: {lat} must be between -90 and 90 degrees")
def mapping_function(self, alt, lat, H, mjd):
"""this implements the Niell mapping function for hydrostatic delays"""
@@ -355,22 +353,16 @@ def _get_year_fraction_slow(self, mjd, lat):
but it's more slow because of the looping
"""
- seasonOffset = 0.0
- if lat < 0:
- seasonOffset = 0.5
-
- yearFraction = np.array(
+ seasonOffset = 0.5 if lat < 0 else 0.0
+ return np.array(
[(i.jyear + seasonOffset + self.DOY_OFFSET / 365.25) % 1.0 for i in mjd]
)
- return yearFraction
def _get_year_fraction_fast(self, tdbld, lat):
"""
use numpy array arithmetic to calculate the year fraction more quickly
"""
- seasonOffset = 0.0
- if lat < 0:
- seasonOffset = 0.5
+ seasonOffset = 0.5 if lat < 0 else 0.0
return np.mod(
2000.0 + (tdbld - 51544.5 + self.DOY_OFFSET) / (365.25) + seasonOffset, 1.0
)
diff --git a/src/pint/models/wave.py b/src/pint/models/wave.py
index ad8b80acb..7abe4ed79 100644
--- a/src/pint/models/wave.py
+++ b/src/pint/models/wave.py
@@ -79,7 +79,7 @@ def validate(self):
)
self.wave_terms.sort()
wave_in_order = list(range(1, max(self.wave_terms) + 1))
- if not self.wave_terms == wave_in_order:
+ if self.wave_terms != wave_in_order:
diff = list(set(wave_in_order) - set(self.wave_terms))
raise MissingParameter("Wave", "WAVE%d" % diff[0])
@@ -115,5 +115,4 @@ def wave_phase(self, toas, delays):
times += wave_a * np.sin(wave_phase)
times += wave_b * np.cos(wave_phase)
- phase = ((times) * self._parent.F0.quantity).to(u.dimensionless_unscaled)
- return phase
+ return ((times) * self._parent.F0.quantity).to(u.dimensionless_unscaled)
diff --git a/src/pint/modelutils.py b/src/pint/modelutils.py
index aab120033..0a2842788 100644
--- a/src/pint/modelutils.py
+++ b/src/pint/modelutils.py
@@ -21,34 +21,33 @@ def model_ecliptic_to_equatorial(model, force=False):
new model with AstrometryEquatorial component
"""
- if not ("AstrometryEquatorial" in model.components) or force:
+ if "AstrometryEquatorial" not in model.components or force:
if "AstrometryEquatorial" in model.components:
log.warning(
"Equatorial coordinates already present but re-calculating anyway"
)
- if "AstrometryEcliptic" in model.components:
- c = model.coords_as_ICRS()
- a = AstrometryEquatorial()
+ if "AstrometryEcliptic" not in model.components:
+ raise AttributeError(
+ "Requested conversion to equatorial coordinates, but no alternate coordinates found"
+ )
- a.POSEPOCH = model.POSEPOCH
- a.PX = model.PX
+ c = model.coords_as_ICRS()
+ a = AstrometryEquatorial()
- a.RAJ.quantity = c.ra
- a.DECJ.quantity = c.dec
- a.PMRA.quantity = c.pm_ra_cosdec
- a.PMDEC.quantity = c.pm_dec
+ a.POSEPOCH = model.POSEPOCH
+ a.PX = model.PX
- model.add_component(a)
- model.remove_component("AstrometryEcliptic")
+ a.RAJ.quantity = c.ra
+ a.DECJ.quantity = c.dec
+ a.PMRA.quantity = c.pm_ra_cosdec
+ a.PMDEC.quantity = c.pm_dec
- model.setup()
- model.validate()
+ model.remove_component("AstrometryEcliptic")
+ model.add_component(a)
- else:
- raise AttributeError(
- "Requested conversion to equatorial coordinates, but no alternate coordinates found"
- )
+ model.setup()
+ model.validate()
else:
log.warning("Equatorial coordinates already present; not re-calculating")
@@ -72,33 +71,32 @@ def model_equatorial_to_ecliptic(model, force=False):
new model with AstrometryEcliptic component
"""
- if not ("AstrometryEcliptic" in model.components) or force:
+ if "AstrometryEcliptic" not in model.components or force:
if "AstrometryEcliptic" in model.components:
log.warning(
"Ecliptic coordinates already present but re-calculating anyway"
)
- if "AstrometryEquatorial" in model.components:
- c = model.coords_as_ECL()
- a = AstrometryEcliptic()
+ if "AstrometryEquatorial" not in model.components:
+ raise AttributeError(
+ "Requested conversion to ecliptic coordinates, but no alternate coordinates found"
+ )
- a.POSEPOCH = model.POSEPOCH
- a.PX = model.PX
+ c = model.coords_as_ECL()
+ a = AstrometryEcliptic()
- a.ELONG.quantity = c.lon
- a.ELAT.quantity = c.lat
- a.PMELONG.quantity = c.pm_lon_coslat
- a.PMELAT.quantity = c.pm_lat
+ a.POSEPOCH = model.POSEPOCH
+ a.PX = model.PX
- model.add_component(a)
- model.remove_component("AstrometryEquatorial")
+ a.ELONG.quantity = c.lon
+ a.ELAT.quantity = c.lat
+ a.PMELONG.quantity = c.pm_lon_coslat
+ a.PMELAT.quantity = c.pm_lat
- model.setup()
- model.validate()
+ model.remove_component("AstrometryEquatorial")
+ model.add_component(a)
- else:
- raise AttributeError(
- "Requested conversion to ecliptic coordinates, but no alternate coordinates found"
- )
+ model.setup()
+ model.validate()
else:
log.warning("Ecliptic coordinates already present; not re-calculating")
diff --git a/src/pint/observatory/__init__.py b/src/pint/observatory/__init__.py
index 187befa74..7887b8bf9 100644
--- a/src/pint/observatory/__init__.py
+++ b/src/pint/observatory/__init__.py
@@ -22,9 +22,7 @@
"""
import os
-import sys
import textwrap
-import warnings
from collections import defaultdict
from io import StringIO
from pathlib import Path
@@ -156,9 +154,7 @@ def __new__(cls, name, *args, **kwargs):
raise ValueError(
f"Observatory {name.lower} already present and overwrite=False"
)
- log.warning(
- "Observatory '%s' already present; overwriting..." % name.lower()
- )
+ log.warning(f"Observatory '{name.lower()}' already present; overwriting...")
cls._register(obs, name)
return obs
@@ -296,16 +292,16 @@ def get(cls, name):
site_astropy = astropy.coordinates.EarthLocation.of_site(name)
except astropy.coordinates.errors.UnknownSiteException as e:
# turn it into the same error type as PINT would have returned
- raise KeyError("Observatory name '%s' is not defined" % name) from e
+ raise KeyError(f"Observatory name '{name}' is not defined") from e
# we need to import this here rather than up-top because of circular import issues
from pint.observatory.topo_obs import TopoObs
+ # add in metadata from astropy
obs = TopoObs(
name,
location=site_astropy,
- # add in metadata from astropy
- origin="astropy: '%s'" % site_astropy.info.meta["source"],
+ origin=f"""astropy: '{site_astropy.info.meta["source"]}'""",
)
# add to registry
cls._register(obs, name)
@@ -438,7 +434,7 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
)
return self._get_TDB_ephem(t, ephem)
else:
- raise ValueError("Unknown method '%s'." % method)
+ raise ValueError(f"Unknown method '{method}'.")
def _get_TDB_default(self, t, ephem):
return t.tdb
@@ -525,10 +521,10 @@ def compare_t2_observatories_dat(t2dir=None):
"""
if t2dir is None:
t2dir = os.getenv("TEMPO2")
- if t2dir is None:
- raise ValueError(
- "TEMPO2 directory not provided and TEMPO2 environment variable not set"
- )
+ if t2dir is None:
+ raise ValueError(
+ "TEMPO2 directory not provided and TEMPO2 environment variable not set"
+ )
filename = os.path.join(t2dir, "observatory", "observatories.dat")
report = defaultdict(list)
@@ -542,7 +538,7 @@ def compare_t2_observatories_dat(t2dir=None):
full_name, short_name = full_name.lower(), short_name.lower()
topo_obs_entry = textwrap.dedent(
f"""
- "{full_name}": {
+ "{full_name}": {{
"aliases": [
"{short_name}"
],
@@ -551,7 +547,7 @@ def compare_t2_observatories_dat(t2dir=None):
{y},
{z}
]
- }
+ }}
"""
)
try:
@@ -624,7 +620,7 @@ def compare_tempo_obsys_dat(tempodir=None):
report = defaultdict(list)
with open(filename) as f:
- for line in f.readlines():
+ for line in f:
if line.strip().startswith("#"):
continue
try:
@@ -663,7 +659,7 @@ def convert_angle(x):
name = obsnam.replace(" ", "_")
topo_obs_entry = textwrap.dedent(
f"""
- "{name}": {
+ "{name}": {{
"itrf_xyz": [
{x},
{y},
@@ -671,7 +667,7 @@ def convert_angle(x):
],
"tempo_code": "{tempo_code}",
"itoa_code": "{itoa_code}"
- }
+ }}
"""
)
try:
diff --git a/src/pint/observatory/clock_file.py b/src/pint/observatory/clock_file.py
index f66d02bcb..2d3fe0667 100644
--- a/src/pint/observatory/clock_file.py
+++ b/src/pint/observatory/clock_file.py
@@ -125,7 +125,7 @@ def read(cls, filename, format="tempo", **kwargs):
if format in cls._formats:
return cls._formats[format](filename, **kwargs)
else:
- raise ValueError("clock file format '%s' not defined" % format)
+ raise ValueError(f"clock file format '{format}' not defined")
@property
def time(self):
@@ -182,10 +182,7 @@ def evaluate(self, t, limits="warn"):
def last_correction_mjd(self):
"""Last MJD for which corrections are available."""
- if len(self.time) == 0:
- return -np.inf
- else:
- return self.time[-1].mjd
+ return -np.inf if len(self.time) == 0 else self.time[-1].mjd
@staticmethod
def merge(clocks, *, trim=True):
@@ -249,8 +246,8 @@ def merge(clocks, *, trim=True):
]
corr += this_corr
if trim:
- b = max([c._time.mjd[0] for c in clocks])
- e = min([c._time.mjd[-1] for c in clocks])
+ b = max(c._time.mjd[0] for c in clocks)
+ e = min(c._time.mjd[-1] for c in clocks)
il = np.searchsorted(times.mjd, b)
ir = np.searchsorted(times.mjd, e, side="right")
times = times[il:ir]
@@ -315,21 +312,21 @@ def write_tempo_clock_file(self, filename, obscode, extra_comment=None):
)
mjds = self.time.mjd
corr = self.clock.to_value(u.us)
- comments = self.comments if self.comments else [""] * len(self.clock)
+ comments = self.comments or [""] * len(self.clock)
# TEMPO writes microseconds
- if extra_comment is not None:
- if self.leading_comment is not None:
- leading_comment = extra_comment.rstrip() + "\n" + self.leading_comment
- else:
- leading_comment = extra_comment.rstrip()
- else:
+ if extra_comment is None:
leading_comment = self.leading_comment
+ elif self.leading_comment is not None:
+ leading_comment = extra_comment.rstrip() + "\n" + self.leading_comment
+ else:
+ leading_comment = extra_comment.rstrip()
with open_or_use(filename, "wt") as f:
f.write(tempo_standard_header)
if leading_comment is not None:
f.write(leading_comment.strip())
f.write("\n")
# Do not use EECO-REF column as TEMPO does a weird subtraction thing
+ # sourcery skip: hoist-statement-from-loop
for mjd, corr, comment in zip(mjds, corr, comments):
# 0:9 for MJD
# 9:21 for clkcorr1 (do not use)
@@ -380,20 +377,19 @@ def write_tempo2_clock_file(self, filename, hdrline=None, extra_comment=None):
hdrline = self.header
if not hdrline.startswith("#"):
raise ValueError(f"Header line must start with #: {hdrline!r}")
- if extra_comment is not None:
- if self.leading_comment is not None:
- leading_comment = extra_comment.rstrip() + "\n" + self.leading_comment
- else:
- leading_comment = extra_comment.rstrip()
- else:
+ if extra_comment is None:
leading_comment = self.leading_comment
+ elif self.leading_comment is not None:
+ leading_comment = extra_comment.rstrip() + "\n" + self.leading_comment
+ else:
+ leading_comment = extra_comment.rstrip()
with open_or_use(filename, "wt") as f:
f.write(hdrline.rstrip())
f.write("\n")
if leading_comment is not None:
f.write(leading_comment.rstrip())
f.write("\n")
- comments = self.comments if self.comments else [""] * len(self.time)
+ comments = self.comments or [""] * len(self.time)
for mjd, corr, comment in zip(
self.time.mjd, self.clock.to_value(u.s), comments
@@ -512,7 +508,7 @@ def add_comment(s):
# Anything else on the line is a comment too
add_comment(m.group(3))
clk = np.array(clk)
- except (FileNotFoundError, OSError):
+ except OSError:
raise NoClockCorrections(
f"TEMPO2-style clock correction file {filename} not found"
)
@@ -687,7 +683,7 @@ def add_comment(s):
# Parse MJD
try:
- mjd = float(l[0:9])
+ mjd = float(l[:9])
# allow mjd=0 to pass, since that is often used
# for effectively null clock files
if (mjd < 39000 and mjd != 0) or mjd > 100000:
@@ -741,7 +737,7 @@ def add_comment(s):
clkcorrs.append(clkcorr2 - clkcorr1)
comments.append(None)
add_comment(l[50:])
- except (FileNotFoundError, OSError):
+ except OSError:
raise NoClockCorrections(
f"TEMPO-style clock correction file {filename} "
f"for site {obscode} not found"
@@ -830,16 +826,10 @@ def update(self):
f = get_clock_correction_file(
self.filename, url_base=self.url_base, url_mirrors=self.url_mirrors
)
- if f == self.f and f.stat().st_mtime == mtime:
- # Nothing changed
- pass
- else:
+ if f != self.f or f.stat().st_mtime != mtime:
self.f = f
h = compute_hash(f)
- if h == self.hash:
- # Nothing changed but we got it from the Net
- pass
- else:
+ if h != self.hash:
# Actual new data (probably)!
self.hash = h
self.clock_file = ClockFile.read(
diff --git a/src/pint/observatory/global_clock_corrections.py b/src/pint/observatory/global_clock_corrections.py
index 5c907fb69..755895636 100644
--- a/src/pint/observatory/global_clock_corrections.py
+++ b/src/pint/observatory/global_clock_corrections.py
@@ -78,9 +78,8 @@ def get_file(
url_base = global_clock_correction_url_base
if url_mirrors is None:
url_mirrors = global_clock_correction_url_mirrors
- else:
- if url_mirrors is None:
- url_mirrors = [url_base]
+ elif url_mirrors is None:
+ url_mirrors = [url_base]
local_file = None
remote_url = url_base + name
mirror_urls = [u + name for u in url_mirrors]
@@ -89,10 +88,10 @@ def get_file(
try:
local_file = Path(download_file(remote_url, cache=True, sources=[]))
log.trace(f"file {remote_url} found in cache at path: {local_file}")
- except KeyError:
+ except KeyError as e:
log.trace(f"file {remote_url} not found in cache")
if download_policy == "never":
- raise FileNotFoundError(name)
+ raise FileNotFoundError(name) from e
if download_policy == "if_missing" and local_file is not None:
log.trace(
@@ -133,14 +132,13 @@ def get_file(
try:
return Path(download_file(remote_url, cache="update", sources=mirror_urls))
except IOError as e:
- if download_policy == "if_expired" and local_file is not None:
- warn(
- f"File {name} should be downloaded but {local_file} is being used "
- f"because an error occurred: {e}"
- )
- return local_file
- else:
+ if download_policy != "if_expired" or local_file is None:
raise
+ warn(
+ f"File {name} should be downloaded but {local_file} is being used "
+ f"because an error occurred: {e}"
+ )
+ return local_file
IndexEntry = collections.namedtuple(
@@ -175,10 +173,7 @@ def __init__(self, download_policy="if_expired", url_base=None, url_mirrors=None
if not line:
continue
e = line.split(maxsplit=3)
- if e[2] == "---":
- date = None
- else:
- date = Time(e[2], format="iso")
+ date = None if e[2] == "---" else Time(e[2], format="iso")
t = IndexEntry(
file=e[0],
update_interval_days=float(e[1]),
diff --git a/src/pint/observatory/satellite_obs.py b/src/pint/observatory/satellite_obs.py
index 3d7117b33..2021c0b48 100644
--- a/src/pint/observatory/satellite_obs.py
+++ b/src/pint/observatory/satellite_obs.py
@@ -68,7 +68,7 @@ def load_Fermi_FT2(ft2_filename):
# Otherwise, compute velocities by differentiation because FT2 does not have velocities
# This is not the best way. Should fit an orbit and determine velocity from that.
dt = mjds_TT[1] - mjds_TT[0]
- log.info("FT2 spacing is " + str(dt.to(u.s)))
+ log.info(f"FT2 spacing is {str(dt.to(u.s))}")
# Use "spacing" argument for gradient to handle nonuniform entries
tt = mjds_TT.to(u.s).value
Vx = np.gradient(X.value, tt) * u.m / u.s
@@ -79,12 +79,11 @@ def load_Fermi_FT2(ft2_filename):
mjds_TT.min(), mjds_TT.max()
)
)
- FT2_table = Table(
+ return Table(
[mjds_TT, X, Y, Z, Vx, Vy, Vz],
names=("MJD_TT", "X", "Y", "Z", "Vx", "Vy", "Vz"),
meta={"name": "FT2"},
)
- return FT2_table
def load_FPorbit(orbit_filename):
@@ -119,14 +118,14 @@ def load_FPorbit(orbit_filename):
# TIMEREF should be 'LOCAL', since no delays are applied
- if not "TIMESYS" in FPorbit_hdr:
+ if "TIMESYS" not in FPorbit_hdr:
log.warning("Keyword TIMESYS is missing. Assuming TT")
timesys = "TT"
else:
timesys = FPorbit_hdr["TIMESYS"]
log.debug("FPorbit TIMESYS {0}".format(timesys))
- if not "TIMEREF" in FPorbit_hdr:
+ if "TIMEREF" not in FPorbit_hdr:
log.warning("Keyword TIMESYS is missing. Assuming TT")
timeref = "LOCAL"
else:
@@ -235,12 +234,11 @@ def load_nustar_orbit(orb_filename):
mjds_TT.min(), mjds_TT.max()
)
)
- orb_table = Table(
+ return Table(
[mjds_TT, X, Y, Z, Vx, Vy, Vz],
names=("MJD_TT", "X", "Y", "Z", "Vx", "Vy", "Vz"),
meta={"name": "orb"},
)
- return orb_table
def load_orbit(obs_name, orb_filename):
@@ -262,10 +260,8 @@ def load_orbit(obs_name, orb_filename):
if str(orb_filename).startswith("@"):
# Read multiple orbit files names
- orb_list = []
fnames = [ll.strip() for ll in open(orb_filename[1:]).readlines()]
- for fn in fnames:
- orb_list.append(load_orbit(obs_name, fn))
+ orb_list = [load_orbit(obs_name, fn) for fn in fnames]
full_orb = vstack(orb_list)
# Make sure full table is sorted
full_orb.sort("MJD_TT")
@@ -283,7 +279,7 @@ def load_orbit(obs_name, orb_filename):
elif "nustar" in lower_name:
return load_nustar_orbit(orb_filename)
else:
- raise ValueError("Unrecognized satellite observatory %s." % (obs_name))
+ raise ValueError(f"Unrecognized satellite observatory {obs_name}.")
class SatelliteObs(SpecialLocation):
@@ -430,8 +426,7 @@ def posvel_gcrs(self, t, ephem=None):
np.array([self.Vx(t.tt.mjd), self.Vy(t.tt.mjd), self.Vz(t.tt.mjd)])
* self.FT2["Vx"].unit
)
- sat_posvel = PosVel(sat_pos_geo, sat_vel_geo, origin="earth", obj=self.name)
- return sat_posvel
+ return PosVel(sat_pos_geo, sat_vel_geo, origin="earth", obj=self.name)
def get_satellite_observatory(name, ft2name, **kwargs):
diff --git a/src/pint/observatory/special_locations.py b/src/pint/observatory/special_locations.py
index 950cbc5aa..b5b56047c 100644
--- a/src/pint/observatory/special_locations.py
+++ b/src/pint/observatory/special_locations.py
@@ -233,8 +233,7 @@ def posvel_gcrs(self, t, group, ephem=None):
pos_geo = self.get_gcrs(t, group, ephem=None)
- stl_posvel = PosVel(pos_geo, vel_geo, origin="earth", obj="spacecraft")
- return stl_posvel
+ return PosVel(pos_geo, vel_geo, origin="earth", obj="spacecraft")
def posvel(self, t, ephem, group=None):
if group is None:
diff --git a/src/pint/observatory/topo_obs.py b/src/pint/observatory/topo_obs.py
index 915616140..b40e5ccc5 100644
--- a/src/pint/observatory/topo_obs.py
+++ b/src/pint/observatory/topo_obs.py
@@ -177,8 +177,7 @@ def __init__(
]
if sum(input_values) == 0:
raise ValueError(
- "EarthLocation, ITRF coordinates, or lat/lon/height are required for observatory '%s'"
- % name
+ f"EarthLocation, ITRF coordinates, or lat/lon/height are required for observatory '{name}'"
)
if sum(input_values) > 1:
raise ValueError(
@@ -198,7 +197,7 @@ def __init__(
# Check for correct array dims
if xyz.shape != (3,):
raise ValueError(
- "Incorrect coordinate dimensions for observatory '%s'" % (name)
+ f"Incorrect coordinate dimensions for observatory '{name}'"
)
# Convert to astropy EarthLocation, ensuring use of ITRF geocentric coordinates
self.location = EarthLocation.from_geocentric(*xyz)
@@ -216,7 +215,7 @@ def __init__(
# If using TEMPO time.dat we need to know the 1-char tempo-style
# observatory code.
if clock_fmt == "tempo" and clock_file == "time.dat" and tempo_code is None:
- raise ValueError("No tempo_code set for observatory '%s'" % name)
+ raise ValueError(f"No tempo_code set for observatory '{name}'")
# GPS corrections
self.include_gps = include_gps
diff --git a/src/pint/orbital/kepler.py b/src/pint/orbital/kepler.py
index d47b91249..9331d65cf 100644
--- a/src/pint/orbital/kepler.py
+++ b/src/pint/orbital/kepler.py
@@ -111,7 +111,7 @@ class Kepler2DParameters(
Parameters
----------
a : float
- semimajor axis
+ semi-major axis
pb : float
binary period
eps1 : float
@@ -127,8 +127,8 @@ class Kepler2DParameters(
def kepler_2d(params, t):
"""Position and velocity of a particle in a Kepler orbit.
- The orbit has semimajor axis a, period pb, and eccentricity
- paramerized by eps1=e*sin(om) and eps2=e*cos(om), and the
+ The orbit has semi-major axis a, period pb, and eccentricity
+ parametrized by eps1=e*sin(om) and eps2=e*cos(om), and the
particle is on the x axis at time t0, while the values
are computed for time t.
@@ -366,7 +366,7 @@ class Kepler3DParameters(
Parameters
----------
a : float
- semimajor axis
+ semi-major axis
pb : float
binary period
eps1 : float
diff --git a/src/pint/phase.py b/src/pint/phase.py
index 39d3cde21..b876abe73 100644
--- a/src/pint/phase.py
+++ b/src/pint/phase.py
@@ -51,21 +51,22 @@ def __new__(cls, arg1, arg2=None):
pulse phase object with arrays of dimensionless :class:`~astropy.units.Quantity`
objects as the ``int`` and ``frac`` parts
"""
- if not hasattr(arg1, "unit"):
- arg1 = u.Quantity(arg1)
- else:
- # This will raise an exception if the argument has any unit not convertible to Unit(dimensionless)
- arg1 = arg1.to(u.dimensionless_unscaled)
+ arg1 = (
+ arg1.to(u.dimensionless_unscaled)
+ if hasattr(arg1, "unit")
+ else u.Quantity(arg1)
+ )
# If arg is scalar, convert to an array of length 1
if arg1.shape == ():
arg1 = arg1.reshape((1,))
if arg2 is None:
ff, ii = numpy.modf(arg1)
else:
- if not hasattr(arg2, "unit"):
- arg2 = u.Quantity(arg2)
- else:
- arg2 = arg2.to(u.dimensionless_unscaled)
+ arg2 = (
+ arg2.to(u.dimensionless_unscaled)
+ if hasattr(arg2, "unit")
+ else u.Quantity(arg2)
+ )
if arg2.shape == ():
arg2 = arg2.reshape((1,))
arg1S = numpy.modf(arg1)
diff --git a/src/pint/pint_matrix.py b/src/pint/pint_matrix.py
index 321c24b1f..1900249ff 100644
--- a/src/pint/pint_matrix.py
+++ b/src/pint/pint_matrix.py
@@ -406,14 +406,18 @@ def __call__(
derivative_params : list
The parameter list for the derivatives 'd_quantity_d_param'.
offset : bool, optional
- Add the an offset to the beginning of design matrix. Default is False.
- This is match the current phase offset in the design matrix.
+ Add the implicit offset to the beginning of design matrix. Default is False.
+ This is to match the current phase offset in the design matrix.
+ This option will be ignored if a `PhaseOffset` component is present in the timing model.
offset_padding : float, optional
if including offset, the value for padding.
"""
# Get derivative functions
deriv_func = getattr(model, self.deriv_func_name)
# Check if the derivate quantity a phase derivative
+
+ offset = offset and "PhaseOffset" not in model.components
+
params = ["Offset"] if offset else []
params += derivative_params
labels = []
@@ -450,12 +454,16 @@ def __call__(self, data, model, derivative_params, offset=True, offset_padding=1
derivative_params : list
The parameter list for the derivatives 'd_quantity_d_param'.
offset : bool, optional
- Add the an offset to the beginning of design matrix. Default is True.
+ Add the the implicit offset to the beginning of design matrix. Default is True.
+ This option will be ignored if a `PhaseOffset` component is present in the timing model.
offset_padding : float, optional
if including offset, the value for padding. Default is 1.0
"""
deriv_func = getattr(model, self.deriv_func_name)
# Check if the derivate quantity a phase derivative
+
+ offset = offset and "PhaseOffset" not in model.components
+
params = ["Offset"] if offset else []
params += derivative_params
labels = []
@@ -723,7 +731,7 @@ def prettyprint(self, prec=3, coordinatefirst=False, offset=False, usecolor=True
coordinatefirst : bool, optional
whether or not the output should be re-ordered to put the coordinates first (after the Offset, if present)
offset : bool, optional
- whether the absolute phase (i.e. "offset") should be shown
+ whether the implicit phase offset (i.e. "Offset") should be shown
usecolor : bool, optional
use color for "problem" CorrelationMatrix params
diff --git a/src/pint/pintk/colormodes.py b/src/pint/pintk/colormodes.py
index 123fc02b2..518509d68 100644
--- a/src/pint/pintk/colormodes.py
+++ b/src/pint/pintk/colormodes.py
@@ -360,7 +360,7 @@ def get_jumps(self):
model = self.application.psr.postfit_model
else:
model = self.application.psr.prefit_model
- if not "PhaseJump" in model.components:
+ if "PhaseJump" not in model.components:
return []
return model.get_jump_param_objects()
diff --git a/src/pint/pintk/paredit.py b/src/pint/pintk/paredit.py
index c523df071..fc7e0e59f 100644
--- a/src/pint/pintk/paredit.py
+++ b/src/pint/pintk/paredit.py
@@ -233,7 +233,7 @@ def setPulsar(self, psr, updates):
self.update_callbacks = updates
def call_updates(self):
- if not self.update_callbacks is None:
+ if self.update_callbacks is not None:
for ucb in self.update_callbacks:
ucb()
@@ -270,13 +270,12 @@ def applyChanges(self):
def writePar(self):
filename = tkFileDialog.asksaveasfilename(title="Choose output par file")
try:
- fout = open(filename, "w")
- fout.write(self.editor.get("1.0", "end-1c"))
- fout.close()
- log.info("Saved parfile to %s" % filename)
- except:
- if filename == () or filename == "":
- log.warning("Write Par cancelled.")
+ with open(filename, "w") as fout:
+ fout.write(self.editor.get("1.0", "end-1c"))
+ log.info(f"Saved parfile to {filename}")
+ except Exception:
+ if filename in [(), ""]:
+ log.warning("Writing par file cancelled.")
else:
log.warning("Could not save parfile to filename:\t%s" % filename)
diff --git a/src/pint/pintk/plk.py b/src/pint/pintk/plk.py
index 044bf87e4..14831aefb 100644
--- a/src/pint/pintk/plk.py
+++ b/src/pint/pintk/plk.py
@@ -8,8 +8,8 @@
from astropy.time import Time
import astropy.units as u
import matplotlib as mpl
+from matplotlib import figure
import numpy as np
-import matplotlib.figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import pint.pintk.pulsar as pulsar
@@ -1095,14 +1095,7 @@ def plotResiduals(self, keepAxes=False):
# Get the time of conjunction after T0 or TASC
tt = m.T0.value if hasattr(m, "T0") else m.TASC.value
mjd = m.conjunction(tt)
- if m.PB.value is not None:
- pb = m.PB.value
- elif m.FB0.quantity is not None:
- pb = (1 / m.FB0.quantity).to("day").value
- else:
- raise AttributeError(
- "Neither PB nor FB0 is present in the timing model."
- )
+ pb = m.pb()[0].to_value("day")
phs = (mjd - tt) / pb
self.plkAxes.plot([phs, phs], [ymin, ymax], "k-")
else:
diff --git a/src/pint/pintk/pulsar.py b/src/pint/pintk/pulsar.py
index 58e33e323..aa938fe31 100644
--- a/src/pint/pintk/pulsar.py
+++ b/src/pint/pintk/pulsar.py
@@ -1,4 +1,4 @@
-"""A wrapper around pulsar functions for pintkinter to use.
+"""A wrapper around pulsar functions for pintk to use.
This object will be shared between widgets in the main frame
and will contain the pre/post fit model, toas,
@@ -167,10 +167,7 @@ def resetAll(self):
def _delete_TOAs(self, toa_table):
del_inds = np.in1d(toa_table["index"], np.array(list(self.deleted)))
- if del_inds.sum() < len(toa_table):
- return toa_table[~del_inds]
- else:
- return None
+ return toa_table[~del_inds] if del_inds.sum() < len(toa_table) else None
def delete_TOAs(self, indices, selected):
# note: indices should be a list or an array
@@ -452,21 +449,11 @@ def add_jump(self, selected):
def getDefaultFitter(self, downhill=False):
if self.all_toas.wideband:
- if downhill:
- return "WidebandDownhillFitter"
- else:
- return "WidebandTOAFitter"
+ return "WidebandDownhillFitter" if downhill else "WidebandTOAFitter"
+ if self.prefit_model.has_correlated_errors:
+ return "DownhillGLSFitter" if downhill else "GLSFitter"
else:
- if self.prefit_model.has_correlated_errors:
- if downhill:
- return "DownhillGLSFitter"
- else:
- return "GLSFitter"
- else:
- if downhill:
- return "DownhillWLSFitter"
- else:
- return "WLSFitter"
+ return "DownhillWLSFitter" if downhill else "WLSFitter"
def print_chi2(self, selected):
# Select all the TOAs if none are explicitly set
diff --git a/src/pint/pintk/timedit.py b/src/pint/pintk/timedit.py
index 84d34a201..d2104f440 100644
--- a/src/pint/pintk/timedit.py
+++ b/src/pint/pintk/timedit.py
@@ -130,7 +130,7 @@ def setPulsar(self, psr, updates):
self.update_callbacks = updates
def call_updates(self):
- if not self.update_callbacks is None:
+ if self.update_callbacks is not None:
for ucb in self.update_callbacks:
ucb()
@@ -184,12 +184,11 @@ def applyChanges(self):
def writeTim(self):
filename = tkFileDialog.asksaveasfilename(title="Choose output tim file")
try:
- fout = open(filename, "w")
- fout.write(self.editor.get("1.0", "end-1c"))
- fout.close()
- log.info("Saved timfile to %s" % filename)
- except:
- if filename == () or filename == "":
- log.warning("Write Tim cancelled.")
+ with open(filename, "w") as fout:
+ fout.write(self.editor.get("1.0", "end-1c"))
+ log.info(f"Saved timfile to {filename}")
+ except Exception:
+ if filename in [(), ""]:
+ log.warning("Writing tim file cancelled.")
else:
log.warning("Could not save timfile to filename:\t%s" % filename)
diff --git a/src/pint/plot_utils.py b/src/pint/plot_utils.py
index 2e642b7a0..e95dae99e 100644
--- a/src/pint/plot_utils.py
+++ b/src/pint/plot_utils.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
-from pint.models.priors import GaussianBoundedRV
import astropy
import astropy.units as u
import astropy.time
@@ -184,15 +183,14 @@ def phaseogram_binned(
if weights is not None:
for ph, ww in zip(phases[idx], weights[idx]):
- bin = int(ph * bins)
- profile[bin] += ww
+ ibin = int(ph * bins)
+ profile[ibin] += ww
else:
for ph in phases[idx]:
- bin = int(ph * bins)
- profile[bin] += 1
+ ibin = int(ph * bins)
+ profile[ibin] += 1
- for i in range(bins):
- a.append(profile[i])
+ a.extend(profile[i] for i in range(bins))
a = np.array(a)
b = a.reshape(ntoa, bins)
@@ -250,7 +248,7 @@ def plot_priors(
samples from the chains is not supported. Can be created using
:meth:`pint.sampler.EmceeSampler.chains_to_dict`
maxpost_fitvals : list, optional
- The maximum posterier values returned from MCMC integration for each
+ The maximum posterior values returned from MCMC integration for each
fitter key. Plots a vertical dashed line to denote the maximum
posterior value in relation to the histogrammed samples. If the
values are not provided, then the lines are not plotted
@@ -271,15 +269,15 @@ def plot_priors(
larger than the other. The scaling is for visual purposes to clearly
plot the priors with the samples
"""
- keys = []
values = []
+ keys = []
for k, v in chains.items():
keys.append(k), values.append(v)
priors = []
x_range = []
counts = []
- for i in range(0, len(keys[:-1])):
+ for i in range(len(keys[:-1])):
values[i] = values[i][burnin:].flatten()
x_range.append(np.linspace(values[i].min(), values[i].max(), num=bins))
priors.append(getattr(model, keys[i]).prior.pdf(x_range[i]))
@@ -291,9 +289,8 @@ def plot_priors(
for i, p in enumerate(keys):
if i != len(keys[:-1]):
axs[i].set_xlabel(
- str(p)
- + ": Mean Value = "
- + str("{:.9e}".format(values[i].mean()))
+ f"{str(p)}: Mean Value = "
+ + "{:.9e}".format(values[i].mean())
+ " ("
+ str(getattr(model, p).units)
+ ")"
diff --git a/src/pint/polycos.py b/src/pint/polycos.py
index c725462dc..686576b9b 100644
--- a/src/pint/polycos.py
+++ b/src/pint/polycos.py
@@ -46,9 +46,7 @@
except (ModuleNotFoundError, ImportError) as e:
def tqdm(*args, **kwargs):
- if args:
- return args[0]
- return kwargs.get("iterable", None)
+ return args[0] if args else kwargs.get("iterable", None)
import pint.toa as toa
@@ -200,8 +198,7 @@ def evalfreq(self, t):
s = data2longdouble(0.0)
for i in range(1, self.ncoeff):
s += data2longdouble(i) * self.coeffs[i] * dt ** (i - 1)
- freq = self.f0 + s / 60.0
- return freq
+ return self.f0 + s / 60.0
def evalfreqderiv(self, t):
"""Return the frequency derivative at time t.
@@ -226,8 +223,7 @@ def evalfreqderiv(self, t):
* self.coeffs[i]
* dt ** (i - 2)
)
- freqd = s / (60.0 * 60.0)
- return freqd
+ return s / (60.0 * 60.0)
# Read polycos file data to table
@@ -305,7 +301,7 @@ def tempo_polyco_table_reader(filename):
fields = f.readline().split()
refPhaseInt, refPhaseFrac = fields[0].split(".")
refPhaseInt = np.longdouble(refPhaseInt)
- refPhaseFrac = np.longdouble("." + refPhaseFrac)
+ refPhaseFrac = np.longdouble(f".{refPhaseFrac}")
if refPhaseInt < 0:
refPhaseFrac = -refPhaseFrac
@@ -324,10 +320,9 @@ def tempo_polyco_table_reader(filename):
# Read coefficients
coeffs = []
- for i in range(-(nCoeff // -3)):
+ for _ in range(-(nCoeff // -3)):
line = f.readline()
- for c in line.split():
- coeffs.append(data2longdouble(c))
+ coeffs.extend(data2longdouble(c) for c in line.split())
coeffs = np.array(coeffs)
entry = PolycoEntry(
@@ -357,8 +352,7 @@ def tempo_polyco_table_reader(filename):
line = f.readline()
- pTable = table.Table(entries, meta={"name": "Polyco Data Table"})
- return pTable
+ return table.Table(entries, meta={"name": "Polyco Data Table"})
def tempo_polyco_table_writer(polycoTable, filename="polyco.dat"):
@@ -511,7 +505,7 @@ def _register(cls, formatlist=_polycoFormats):
cls.add_polyco_file_format(
fmt["format"], "w", None, fmt["write_method"]
)
- elif fmt["read_method"] is not None and fmt["write_method"] is not None:
+ elif fmt["read_method"] is not None:
cls.add_polyco_file_format(
fmt["format"], "rw", fmt["read_method"], fmt["write_method"]
)
@@ -533,9 +527,11 @@ def __init__(self, filename=None, format="tempo"):
log.info(f"Reading polycos from '{filename}'")
if format not in [f["format"] for f in self.polycoFormats]:
raise ValueError(
- "Unknown polyco file format '" + format + "'\n"
- "Please use function 'Polyco.add_polyco_file_format()'"
- " to register the format\n"
+ (
+ f"Unknown polyco file format '{format}" + "'\n"
+ "Please use function 'Polyco.add_polyco_file_format()'"
+ " to register the format\n"
+ )
)
else:
self.fileFormat = format
@@ -655,6 +651,7 @@ def read_polyco_file(self, filename, format="tempo"):
raise DeprecationWarning(
"Use `p=pint.polycos.Polycos.read()` rather than `p.read_polyco_file()`"
)
+
self.fileName = filename
if format not in [f["format"] for f in self.polycoFormats]:
@@ -727,8 +724,6 @@ def generate_polycos(
mjdStart = data2longdouble(mjdStart)
mjdEnd = data2longdouble(mjdEnd)
segLength = int(segLength)
- obsFreq = float(obsFreq)
-
# Use the planetary ephemeris specified in the model, if available.
if model.EPHEM.value is not None:
ephem = model.EPHEM.value
@@ -750,100 +745,100 @@ def generate_polycos(
)
tmids = data2longdouble(tmids) / MIN_PER_DAY
- # generate the ploynomial coefficents
- if method == "TEMPO":
- entryList = []
- # Using tempo1 method to create polycos
- # If you want to disable the progress bar, add disable=True to the tqdm() call.
- for tmid in tqdm(tmids, disable=not progress):
- tStart = tmid - mjdSpan / 2
- tStop = tmid + mjdSpan / 2
- nodes = np.linspace(tStart, tStop, numNodes)
-
- toaMid = toa.get_TOAs_array(
- (np.modf(tmid)[1], np.modf(tmid)[0]),
- obs=obs,
- freqs=obsFreq,
- ephem=ephem,
- )
- # toaMid = toa.get_TOAs_list(
- # [toa.TOA()],
- # )
-
- refPhase = model.phase(toaMid, abs_phase=True)
-
- # Create node toas(Time sample using TOA class)
- # toaList = [
- # toa.TOA(
- # (np.modf(toaNode)[1], np.modf(toaNode)[0]),
- # obs=obs,
- # freq=obsFreq,
- # )
- # for toaNode in nodes
- # ]
-
- # toas = toa.get_TOAs_list(toaList, ephem=ephem)
- toas = toa.get_TOAs_array(
- (np.modf(nodes)[0], np.modf(nodes)[1]),
- obs=obs,
- freqs=obsFreq,
- ephem=ephem,
- )
+ if method != "TEMPO":
+ raise NotImplementedError("Only TEMPO method has been implemented.")
+ entryList = []
+ obsFreq = float(obsFreq)
- ph = model.phase(toas, abs_phase=True)
- dt = (nodes - tmid) * MIN_PER_DAY
- rdcPhase = ph - refPhase
- rdcPhase = rdcPhase.int - (dt * model.F0.value * 60.0) + rdcPhase.frac
- dtd = dt.astype(float) # Truncate to double
- rdcPhased = rdcPhase.astype(float)
- coeffs = np.polyfit(dtd, rdcPhased, ncoeff - 1)[::-1]
-
- date, hms = Time(tmid, format="mjd", scale="utc").iso.split()
- yy, mm, dd = date.split("-")
- date = dd + "-" + MONTHS[int(mm) - 1] + "-" + yy[-2:]
- hms = float(hms.replace(":", ""))
-
- entry = PolycoEntry(
- tmid,
- segLength,
- refPhase.int,
- refPhase.frac,
- model.F0.value,
- ncoeff,
- coeffs,
- )
+ # Using tempo1 method to create polycos
+ # If you want to disable the progress bar, add disable=True to the tqdm() call.
+ for tmid in tqdm(tmids, disable=not progress):
+ tStart = tmid - mjdSpan / 2
+ tStop = tmid + mjdSpan / 2
+ nodes = np.linspace(tStart, tStop, numNodes)
+
+ toaMid = toa.get_TOAs_array(
+ (np.modf(tmid)[1], np.modf(tmid)[0]),
+ obs=obs,
+ freqs=obsFreq,
+ ephem=ephem,
+ )
+ # toaMid = toa.get_TOAs_list(
+ # [toa.TOA()],
+ # )
+
+ refPhase = model.phase(toaMid, abs_phase=True)
+
+ # Create node toas(Time sample using TOA class)
+ # toaList = [
+ # toa.TOA(
+ # (np.modf(toaNode)[1], np.modf(toaNode)[0]),
+ # obs=obs,
+ # freq=obsFreq,
+ # )
+ # for toaNode in nodes
+ # ]
+
+ # toas = toa.get_TOAs_list(toaList, ephem=ephem)
+ toas = toa.get_TOAs_array(
+ (np.modf(nodes)[0], np.modf(nodes)[1]),
+ obs=obs,
+ freqs=obsFreq,
+ ephem=ephem,
+ )
- entry_dict = OrderedDict()
- entry_dict["psr"] = model.PSR.value
- entry_dict["date"] = date
- entry_dict["utc"] = hms
- entry_dict["tmid"] = tmid
- entry_dict["dm"] = model.DM.value
- entry_dict["doppler"] = 0.0
- entry_dict["logrms"] = 0.0
- entry_dict["mjd_span"] = segLength
- entry_dict["t_start"] = entry.tstart
- entry_dict["t_stop"] = entry.tstop
- entry_dict["obs"] = obs
- entry_dict["obsfreq"] = obsFreq
-
- if model.is_binary:
- binphase = model.orbital_phase(toaMid, radians=False)[0]
- entry_dict["binary_phase"] = binphase
- b = model.get_components_by_category()["pulsar_system"][0]
- entry_dict["f_orbit"] = 1 / b.PB.value
-
- entry_dict["entry"] = entry
- entryList.append(entry_dict)
-
- pTable = table.Table(entryList, meta={"name": "Polyco Data Table"})
- out = cls()
- out.polycoTable = pTable
- if len(out.polycoTable) == 0:
- raise ValueError("Zero polycos found for table")
- return out
- else:
- raise NotImplementedError("Only TEMPO method has been implemented.")
+ ph = model.phase(toas, abs_phase=True)
+ dt = (nodes - tmid) * MIN_PER_DAY
+ rdcPhase = ph - refPhase
+ rdcPhase = rdcPhase.int - (dt * model.F0.value * 60.0) + rdcPhase.frac
+ dtd = dt.astype(float) # Truncate to double
+ rdcPhased = rdcPhase.astype(float)
+ coeffs = np.polyfit(dtd, rdcPhased, ncoeff - 1)[::-1]
+
+ date, hms = Time(tmid, format="mjd", scale="utc").iso.split()
+ yy, mm, dd = date.split("-")
+ date = f"{dd}-{MONTHS[int(mm) - 1]}-{yy[-2:]}"
+ hms = float(hms.replace(":", ""))
+
+ entry = PolycoEntry(
+ tmid,
+ segLength,
+ refPhase.int,
+ refPhase.frac,
+ model.F0.value,
+ ncoeff,
+ coeffs,
+ )
+
+ entry_dict = OrderedDict()
+ entry_dict["psr"] = model.PSR.value
+ entry_dict["date"] = date
+ entry_dict["utc"] = hms
+ entry_dict["tmid"] = tmid
+ entry_dict["dm"] = model.DM.value
+ entry_dict["doppler"] = 0.0
+ entry_dict["logrms"] = 0.0
+ entry_dict["mjd_span"] = segLength
+ entry_dict["t_start"] = entry.tstart
+ entry_dict["t_stop"] = entry.tstop
+ entry_dict["obs"] = obs
+ entry_dict["obsfreq"] = obsFreq
+
+ if model.is_binary:
+ binphase = model.orbital_phase(toaMid, radians=False)[0]
+ entry_dict["binary_phase"] = binphase
+ b = model.get_components_by_category()["pulsar_system"][0]
+ entry_dict["f_orbit"] = 1 / b.PB.value
+
+ entry_dict["entry"] = entry
+ entryList.append(entry_dict)
+
+ pTable = table.Table(entryList, meta={"name": "Polyco Data Table"})
+ out = cls()
+ out.polycoTable = pTable
+ if len(out.polycoTable) == 0:
+ raise ValueError("Zero polycos found for table")
+ return out
def write_polyco_file(self, filename="polyco.dat", format="tempo"):
"""Write Polyco table to a file.
@@ -858,9 +853,11 @@ def write_polyco_file(self, filename="polyco.dat", format="tempo"):
if format not in [f["format"] for f in self.polycoFormats]:
raise ValueError(
- "Unknown polyco file format '" + format + "'\n"
- "Please use function 'self.add_polyco_file_format()'"
- " to register the format\n"
+ (
+ f"Unknown polyco file format '{format}" + "'\n"
+ "Please use function 'self.add_polyco_file_format()'"
+ " to register the format\n"
+ )
)
self.polycoTable.write(filename, format=format)
@@ -963,9 +960,7 @@ def eval_abs_phase(self, t):
# Maybe add sort function here, since the time has been masked.
phaseInt = np.hstack(phaseInt).value
phaseFrac = np.hstack(phaseFrac).value
- absPhase = Phase(phaseInt, phaseFrac)
-
- return absPhase
+ return Phase(phaseInt, phaseFrac)
def eval_spin_freq(self, t):
"""
diff --git a/src/pint/pulsar_ecliptic.py b/src/pint/pulsar_ecliptic.py
index 1199692cc..edeb8e75a 100644
--- a/src/pint/pulsar_ecliptic.py
+++ b/src/pint/pulsar_ecliptic.py
@@ -46,7 +46,7 @@ class PulsarEcliptic(coord.BaseCoordinateFrame):
"""
default_representation = coord.SphericalRepresentation
- # NOTE: The feature below needs astropy verison 2.0. Disable it right now
+ # NOTE: The feature below needs astropy version 2.0. Disable it right now
default_differential = coord.SphericalCosLatDifferential
obliquity = QuantityAttribute(default=OBL["DEFAULT"], unit=u.arcsec)
@@ -54,11 +54,11 @@ def __init__(self, *args, **kwargs):
if "ecl" in kwargs:
try:
kwargs["obliquity"] = OBL[kwargs["ecl"]]
- except KeyError:
+ except KeyError as e:
raise ValueError(
"No obliquity " + kwargs["ecl"] + " provided. "
"Check your pint/datafile/ecliptic.dat file."
- )
+ ) from e
del kwargs["ecl"]
super().__init__(*args, **kwargs)
diff --git a/src/pint/pulsar_mjd.py b/src/pint/pulsar_mjd.py
index 8903ef57e..3c0f14d8d 100644
--- a/src/pint/pulsar_mjd.py
+++ b/src/pint/pulsar_mjd.py
@@ -104,10 +104,10 @@ def set_jds(self, val1, val2):
def value(self):
if self._scale == "utc":
mjd1, mjd2 = jds_to_mjds_pulsar(self.jd1, self.jd2)
- return mjd1 + mjd2
else:
mjd1, mjd2 = jds_to_mjds(self.jd1, self.jd2)
- return mjd1 + mjd2
+
+ return mjd1 + mjd2
class MJDLong(TimeFormat):
@@ -257,9 +257,7 @@ def time_from_mjd_string(s, scale="utc", format="pulsar_mjd"):
elif format.lower().startswith("mjd"):
return astropy.time.Time(val=s, scale=scale, format="mjd_string")
else:
- raise ValueError(
- "Format {} is not recognizable as an MJD format".format(format)
- )
+ raise ValueError(f"Format {format} is not recognizable as an MJD format")
def time_from_longdouble(t, scale="utc", format="pulsar_mjd"):
@@ -290,10 +288,7 @@ def time_to_longdouble(t):
double MJDs (near the present) is roughly 0.7 ns.
"""
- if t.format.startswith("pulsar_mjd"):
- return t.pulsar_mjd_long
- else:
- return t.mjd_long
+ return t.pulsar_mjd_long if t.format.startswith("pulsar_mjd") else t.mjd_long
# Precision-aware conversion functions
@@ -330,10 +325,7 @@ def data2longdouble(data):
np.longdouble
"""
- if type(data) is str:
- return str2longdouble(data)
- else:
- return np.longdouble(data)
+ return str2longdouble(data) if type(data) is str else np.longdouble(data)
def quantity2longdouble_withunit(data):
@@ -473,7 +465,7 @@ def _str_to_mjds(s):
mjd_s.append("0")
imjd_s, fmjd_s = mjd_s
imjd = np.longdouble(int(imjd_s))
- fmjd = np.longdouble("0." + fmjd_s)
+ fmjd = np.longdouble(f"0.{fmjd_s}")
if ss.startswith("-"):
fmjd = -fmjd
imjd *= 10**expon
@@ -485,7 +477,7 @@ def _str_to_mjds(s):
mjd_s.append("0")
imjd_s, fmjd_s = mjd_s
imjd = int(imjd_s)
- fmjd = float("0." + fmjd_s)
+ fmjd = float(f"0.{fmjd_s}")
if ss.startswith("-"):
fmjd = -fmjd
return day_frac(imjd, fmjd)
@@ -494,20 +486,19 @@ def _str_to_mjds(s):
def str_to_mjds(s):
if isinstance(s, (str, bytes)):
return _str_to_mjds(s)
- else:
- imjd = np.empty_like(s, dtype=int)
- fmjd = np.empty_like(s, dtype=float)
- with np.nditer(
- [s, imjd, fmjd],
- flags=["refs_ok"],
- op_flags=[["readonly"], ["writeonly"], ["writeonly"]],
- ) as it:
- for si, i, f in it:
- si = si[()]
- if not isinstance(si, (str, bytes)):
- raise TypeError("Requires an array of strings")
- i[...], f[...] = _str_to_mjds(si)
- return it.operands[1], it.operands[2]
+ imjd = np.empty_like(s, dtype=int)
+ fmjd = np.empty_like(s, dtype=float)
+ with np.nditer(
+ [s, imjd, fmjd],
+ flags=["refs_ok"],
+ op_flags=[["readonly"], ["writeonly"], ["writeonly"]],
+ ) as it:
+ for si, i, f in it:
+ si = si[()]
+ if not isinstance(si, (str, bytes)):
+ raise TypeError("Requires an array of strings")
+ i[...], f[...] = _str_to_mjds(si)
+ return it.operands[1], it.operands[2]
def _mjds_to_str(mjd1, mjd2):
@@ -527,10 +518,7 @@ def _mjds_to_str(mjd1, mjd2):
def mjds_to_str(mjd1, mjd2):
r = _v_mjds_to_str(mjd1, mjd2)
- if r.shape == ():
- return r[()]
- else:
- return r
+ return r[()] if r.shape == () else r
# These routines are from astropy but were broken in < 3.2.2 and <= 2.0.15
diff --git a/src/pint/random_models.py b/src/pint/random_models.py
index 2ec7bfe9c..391bc2f2a 100644
--- a/src/pint/random_models.py
+++ b/src/pint/random_models.py
@@ -72,7 +72,7 @@ def random_models(
rss = []
random_models = []
- for i in range(iter):
+ for _ in range(iter):
# create a set of randomized parameters based on mean vector and covariance matrix
rparams_num = np.random.multivariate_normal(mean_vector, cov_matrix)
# scale params back to real units
diff --git a/src/pint/residuals.py b/src/pint/residuals.py
index 51fe0f3ac..65ad45ef0 100644
--- a/src/pint/residuals.py
+++ b/src/pint/residuals.py
@@ -1,7 +1,7 @@
"""Objects for comparing models to data.
These objects can be constructed directly, as ``Residuals(toas, model)``, or
-they are contructed during fitting operations with :class:`pint.fitter.Fitter`
+they are constructed during fitting operations with :class:`pint.fitter.Fitter`
objects, as ``fitter.residual``. Variants exist for arrival-time-only data
(:class:`pint.residuals.Residuals`) and for arrival times that come paired with
dispersion measures (:class:`pint.residuals.WidebandTOAResiduals`).
@@ -54,16 +54,17 @@ class also serves as a base class providing some infrastructure to support
----------
toas: :class:`pint.toa.TOAs`, optional
The input TOAs object. Default: None
- model: :class:`pint.models.timing_model.TimingModel`, optinonal
+ model: :class:`pint.models.timing_model.TimingModel`, optional
Input model object. Default: None
residual_type: str, optional
- The type of the resiudals. Default: 'toa'
+ The type of the residuals. Default: 'toa'
unit: :class:`astropy.units.Unit`, optional
- The defualt unit of the residuals. Default: u.s
+ The default unit of the residuals. Default: u.s
subtract_mean : bool
- Controls whether mean will be subtracted from the residuals
+ Controls whether mean will be subtracted from the residuals.
+ This option will be ignored if a `PhaseOffset` is present in the timing model.
use_weighted_mean : bool
- Controls whether mean compution is weighted (by errors) or not.
+ Controls whether mean computation is weighted (by errors) or not.
track_mode : None, "nearest", "use_pulse_numbers"
Controls how pulse numbers are assigned. ``"nearest"`` assigns
each TOA to the nearest integer pulse. ``"use_pulse_numbers"`` uses the
@@ -83,16 +84,14 @@ def __new__(
subtract_mean=True,
use_weighted_mean=True,
track_mode=None,
+ use_abs_phase=True,
):
if cls is Residuals:
try:
cls = residual_map[residual_type.lower()]
- except KeyError:
+ except KeyError as e:
raise ValueError(
- "'{}' is not a PINT supported residual. Currently "
- "supported data types are {}".format(
- residual_type, list(residual_map.keys())
- )
+ f"'{residual_type}' is not a PINT supported residual. Currently supported data types are {list(residual_map.keys())}"
)
return super().__new__(cls)
@@ -106,11 +105,20 @@ def __init__(
subtract_mean=True,
use_weighted_mean=True,
track_mode=None,
+ use_abs_phase=True,
):
self.toas = toas
self.model = model
self.residual_type = residual_type
- self.subtract_mean = subtract_mean
+
+ if "PhaseOffset" in model.components and subtract_mean:
+ log.debug(
+ "Disabling implicit `subtract_mean` because `PhaseOffset` is present in the timing model."
+ )
+ self.subtract_mean = subtract_mean and "PhaseOffset" not in model.components
+
+ self.use_abs_phase = use_abs_phase
+
self.use_weighted_mean = use_weighted_mean
if track_mode is None:
if getattr(self.model, "TRACK").value == "-2":
@@ -135,6 +143,7 @@ def __init__(
else:
self.phase_resids = None
self.time_resids = None
+
# delay chi-squared computation until needed to avoid infinite recursion
# also it's expensive
# only relevant if there are correlated errors
@@ -144,7 +153,7 @@ def __init__(
self.debug_info = {}
# We should be carefully for the other type of residuals
self.unit = unit
- # A flag to indentify if this residual object is combined with residual
+ # A flag to identify if this residual object is combined with residual
# class.
self._is_combined = False
@@ -224,13 +233,14 @@ def get_data_error(self, scaled=True):
scaled: bool, optional
If errors get scaled by the noise model.
"""
- if not scaled:
- return self.toas.get_errors()
- else:
- return self.model.scaled_toa_uncertainty(self.toas)
+ return (
+ self.model.scaled_toa_uncertainty(self.toas)
+ if scaled
+ else self.toas.get_errors()
+ )
def rms_weighted(self):
- """Compute weighted RMS of the residals in time."""
+ """Compute weighted RMS of the residuals in time."""
# Use scaled errors, if the noise model is not presented, it will
# return the raw errors
scaled_errors = self.get_data_error()
@@ -262,7 +272,7 @@ def get_PSR_freq(self, calctype="modelF0"):
assert calctype.lower() in ["modelf0", "taylor", "numerical"]
if calctype.lower() == "modelf0":
# TODO this function will be re-write and move to timing model soon.
- # The following is a temproary patch.
+ # The following is a temporary patch.
if "Spindown" in self.model.components:
F0 = self.model.F0.quantity
elif "P0" in self.model.params:
@@ -291,17 +301,49 @@ def get_PSR_freq(self, calctype="modelF0"):
elif calctype.lower() == "numerical":
return self.model.d_phase_d_toa(self.toas)
- def calc_phase_resids(self):
- """Compute timing model residuals in pulse phase."""
+ def calc_phase_resids(
+ self, subtract_mean=None, use_weighted_mean=None, use_abs_phase=None
+ ):
+ """Compute timing model residuals in pulse phase.
+
+ if ``subtract_mean`` or ``use_weighted_mean`` is None, will use the values set for the object itself
+
+ Parameters
+ ----------
+ subtract_mean : bool or None, optional
+ Subtract the mean of the residuals. This is ignored if the `PhaseOffset` component
+ is present in the model. Default is to use the class attribute.
+ use_weighted_mean : bool or None, optional
+ Whether to use weighted mean for mean subtraction. Default is to use the class attribute.
+ use_abs_phase : bool or None, optional
+ Whether to use absolute phase (w.r.t. the TZR TOA). Default is to use the class attribute.
+
+ Returns
+ -------
+ Phase
+ """
+
+ if subtract_mean is None:
+ subtract_mean = self.subtract_mean
+
+ if "PhaseOffset" in self.model.components and subtract_mean:
+ log.debug(
+ "Ignoring `subtract_mean` because `PhaseOffset` is present in the timing model."
+ )
+ subtract_mean = subtract_mean and "PhaseOffset" not in self.model.components
+
+ if use_weighted_mean is None:
+ use_weighted_mean = self.use_weighted_mean
+
+ if use_abs_phase is None:
+ use_abs_phase = self.use_abs_phase
# Read any delta_pulse_numbers that are in the TOAs table.
# These are for PHASE statements, -padd flags, as well as user-inserted phase jumps
# Check for the column, and if not there then create it as zeros
- try:
- delta_pulse_numbers = Phase(self.toas.table["delta_pulse_number"])
- except IndexError:
+ if "delta_pulse_number" not in self.toas.table.colnames:
self.toas.table["delta_pulse_number"] = np.zeros(len(self.toas.get_mjds()))
- delta_pulse_numbers = Phase(self.toas.table["delta_pulse_number"])
+ delta_pulse_numbers = Phase(self.toas.table["delta_pulse_number"])
# Track on pulse numbers, if requested
if self.track_mode == "use_pulse_numbers":
@@ -314,41 +356,38 @@ def calc_phase_resids(self):
# we need absolute phases, since TZRMJD serves as the pulse
# number reference.
modelphase = (
- self.model.phase(self.toas, abs_phase=True) + delta_pulse_numbers
+ self.model.phase(self.toas, abs_phase=use_abs_phase)
+ + delta_pulse_numbers
)
# First assign each TOA to the correct relative pulse number, including
# and delta_pulse_numbers (from PHASE lines or adding phase jumps in GUI)
i = pulse_num.copy()
f = np.zeros_like(pulse_num)
- c = np.isnan(pulse_num)
- if np.any(c):
+ if np.any(np.isnan(pulse_num)):
raise ValueError("Pulse numbers are missing on some TOAs")
- i[c] = 0
residualphase = modelphase - Phase(i, f)
# This converts from a Phase object to a np.float128
full = residualphase.int + residualphase.frac
- if np.any(c):
- full[c] -= np.round(full[c])
- # If not tracking then do the usual nearest pulse number calculation
elif self.track_mode == "nearest":
# Compute model phase
modelphase = self.model.phase(self.toas) + delta_pulse_numbers
# Here it subtracts the first phase, so making the first TOA be the
# reference. Not sure this is a good idea.
- if self.subtract_mean:
+ if subtract_mean:
modelphase -= Phase(modelphase.int[0], modelphase.frac[0])
# Here we discard the integer portion of the residual and replace it with 0
- # This is effectively selecting the nearst pulse to compute the residual to.
+ # This is effectively selecting the nearest pulse to compute the residual to.
residualphase = Phase(np.zeros_like(modelphase.frac), modelphase.frac)
# This converts from a Phase object to a np.float128
full = residualphase.int + residualphase.frac
else:
- raise ValueError("Invalid track_mode '{}'".format(self.track_mode))
+ raise ValueError(f"Invalid track_mode '{self.track_mode}'")
+
# If we are using pulse numbers, do we really want to subtract any kind of mean?
- if not self.subtract_mean:
+ if not subtract_mean:
return full
- if not self.use_weighted_mean:
+ if not use_weighted_mean:
mean = full.mean()
else:
# Errs for weighted sum. Units don't matter since they will
@@ -359,14 +398,66 @@ def calc_phase_resids(self):
)
w = 1.0 / (self.get_data_error().value ** 2)
mean, err = weighted_mean(full, w)
+
return full - mean
- def calc_time_resids(self, calctype="taylor"):
+ def calc_phase_mean(self, weighted=True):
+ """Calculate mean phase of residuals, optionally weighted
+
+ Parameters
+ ----------
+ weighted : bool, optional
+
+ Returns
+ -------
+ astropy.units.Quantity
+ """
+ r = self.calc_phase_resids(subtract_mean=False)
+ if not weighted:
+ return r.mean()
+ if np.any(self.get_data_error() == 0):
+ raise ValueError("Some TOA errors are zero - cannot calculate residuals")
+ w = 1.0 / (self.get_data_error().value ** 2)
+ mean, _ = weighted_mean(r, w)
+ return mean
+
+ def calc_time_mean(self, calctype="taylor", weighted=True):
+ """Calculate mean time of residuals, optionally weighted
+
+ Parameters
+ ----------
+ calctype : str, optional
+ Calculation time for phase to time conversion. See :meth:`pint.residuals.Residuals.calc_time_resids` for details.
+ weighted : bool, optional
+
+ Returns
+ -------
+ astropy.units.Quantity
+ """
+
+ r = self.calc_time_resids(calctype=calctype, subtract_mean=False)
+ if not weighted:
+ return r.mean()
+ if np.any(self.get_data_error() == 0):
+ raise ValueError("Some TOA errors are zero - cannot calculate residuals")
+ w = 1.0 / (self.get_data_error().value ** 2)
+ mean, _ = weighted_mean(r, w)
+ return mean
+
+ def calc_time_resids(
+ self,
+ calctype="taylor",
+ subtract_mean=None,
+ use_weighted_mean=None,
+ use_abs_phase=None,
+ ):
"""Compute timing model residuals in time (seconds).
Converts from phase residuals to time residuals using several possible ways
to calculate the frequency.
+ If ``subtract_mean`` or ``use_weighted_mean`` is None, will use the values set for the object itself
+
Parameters
----------
calctype : {'taylor', 'modelF0', 'numerical'}
@@ -374,6 +465,13 @@ def calc_time_resids(self, calctype="taylor"):
parameter from the model.
If `calctype` == "numerical", then try a numerical derivative
If `calctype` == "taylor", evaluate the frequency with a Taylor series
+ subtract_mean : bool or None, optional
+ Subtract the mean of the residuals. This is ignored if the `PhaseOffset` component
+ is present in the model. Default is to use the class attribute.
+ use_weighted_mean : bool or None, optional
+ Whether to use weighted mean for mean subtraction. Default is to use the class attribute.
+ use_abs_phase : bool or None, optional
+ Whether to use absolute phase (w.r.t. the TZR TOA). Default is to use the class attribute.
Returns
-------
@@ -384,9 +482,22 @@ def calc_time_resids(self, calctype="taylor"):
:meth:`pint.residuals.Residuals.get_PSR_freq`
"""
assert calctype.lower() in ["modelf0", "taylor", "numerical"]
- if self.phase_resids is None:
- self.phase_resids = self.calc_phase_resids()
- return (self.phase_resids / self.get_PSR_freq(calctype=calctype)).to(u.s)
+ if subtract_mean is None and use_weighted_mean is None:
+ # if we are using the defaults, save the calculation
+ if self.phase_resids is None:
+ self.phase_resids = self.calc_phase_resids(
+ subtract_mean=subtract_mean,
+ use_weighted_mean=use_weighted_mean,
+ use_abs_phase=use_abs_phase,
+ )
+ phase_resids = self.phase_resids
+ else:
+ phase_resids = self.calc_phase_resids(
+ subtract_mean=subtract_mean,
+ use_weighted_mean=use_weighted_mean,
+ use_abs_phase=use_abs_phase,
+ )
+ return (phase_resids / self.get_PSR_freq(calctype=calctype)).to(u.s)
def calc_chi2(self, full_cov=False):
"""Return the weighted chi-squared for the model and toas.
@@ -428,20 +539,19 @@ def calc_chi2(self, full_cov=False):
toa_errors = self.get_data_error()
if (toa_errors == 0.0).any():
return np.inf
- else:
- # The self.time_resids is in the unit of "s", the error "us".
- # This is more correct way, but it is the slowest.
- # return (((self.time_resids / self.toas.get_errors()).decompose()**2.0).sum()).value
+ # The self.time_resids is in the unit of "s", the error "us".
+ # This is more correct way, but it is the slowest.
+ # return (((self.time_resids / self.toas.get_errors()).decompose()**2.0).sum()).value
- # This method is faster then the method above but not the most correct way
- # return ((self.time_resids.to(u.s) / self.toas.get_errors().to(u.s)).value**2.0).sum()
+ # This method is faster then the method above but not the most correct way
+ # return ((self.time_resids.to(u.s) / self.toas.get_errors().to(u.s)).value**2.0).sum()
- # This the fastest way, but highly depend on the assumption of time_resids and
- # error units. Ensure only a pure number is returned.
- try:
- return ((self.time_resids / toa_errors.to(u.s)) ** 2.0).sum().value
- except ValueError:
- return ((self.time_resids / toa_errors.to(u.s)) ** 2.0).sum()
+ # This the fastest way, but highly depend on the assumption of time_resids and
+ # error units. Ensure only a pure number is returned.
+ try:
+ return ((self.time_resids / toa_errors.to(u.s)) ** 2.0).sum().value
+ except ValueError:
+ return ((self.time_resids / toa_errors.to(u.s)) ** 2.0).sum()
def ecorr_average(self, use_noise_model=True):
"""Uses the ECORR noise model time-binning to compute "epoch-averaged" residuals.
@@ -576,12 +686,14 @@ def dof(self):
" class. The individual residual's dof is not "
"calculated correctly in the combined residuals."
)
- dof = len(self.dm_data)
+
# only get dm type of model component
# TODO provide a function in the timing model to get one type of component
- for cp in self.model.components.values():
- if Dispersion in cp.__class__.__bases__:
- dof -= len(cp.free_params_component)
+ dof = len(self.dm_data) - sum(
+ len(cp.free_params_component)
+ for cp in self.model.components.values()
+ if Dispersion in cp.__class__.__bases__
+ )
dof -= 1
return dof
@@ -593,18 +705,13 @@ def get_data_error(self, scaled=True):
scaled: bool, optional
If errors get scaled by the noise model.
"""
- if not scaled:
- return self.dm_error
- else:
- return self.model.scaled_dm_uncertainty(self.toas)
+ return self.model.scaled_dm_uncertainty(self.toas) if scaled else self.dm_error
def calc_resids(self):
model_value = self.get_model_value(self.toas)[self.relevant_toas]
resids = self.dm_data - model_value
if self.subtract_mean:
- if not self.use_weighted_mean:
- resids -= resids.mean()
- else:
+ if self.use_weighted_mean:
# Errs for weighted sum. Units don't matter since they will
# cancel out in the weighted sum.
if self.dm_error is None or np.any(self.dm_error == 0):
@@ -614,20 +721,21 @@ def calc_resids(self):
)
wm = np.average(resids, weights=1.0 / (self.dm_error**2))
resids -= wm
+ else:
+ resids -= resids.mean()
return resids
def calc_chi2(self):
data_errors = self.get_data_error()
if (data_errors == 0.0).any():
return np.inf
- else:
- try:
- return ((self.resids / data_errors) ** 2.0).sum().decompose().value
- except ValueError:
- return ((self.resids / data_errors) ** 2.0).sum().decompose()
+ try:
+ return ((self.resids / data_errors) ** 2.0).sum().decompose().value
+ except ValueError:
+ return ((self.resids / data_errors) ** 2.0).sum().decompose()
def rms_weighted(self):
- """Compute weighted RMS of the residals in time."""
+ """Compute weighted RMS of the residuals in time."""
scaled_errors = self.get_data_error()
if np.any(scaled_errors.value == 0):
raise ValueError(
@@ -688,7 +796,7 @@ class CombinedResiduals:
Parameters
----------
residuals: List of residual objects
- A list of different typs of residual objects
+ A list of different types of residual objects
Note
----
@@ -708,48 +816,41 @@ def __init__(self, residuals):
def model(self):
"""Return the single timing model object."""
raise AttributeError(
- "Combined redisuals object does not provide a "
- "single timing model object. Pleaes use the "
+ "Combined residuals object does not provide a "
+ "single timing model object. Please use the "
"dedicated subclass."
)
@property
def _combined_resids(self):
"""Residuals from all of the residual types."""
- all_resids = []
- for res in self.residual_objs.values():
- all_resids.append(res.resids_value)
+ all_resids = [res.resids_value for res in self.residual_objs.values()]
return np.hstack(all_resids)
@property
def _combined_data_error(self):
- # Since it is the combinde residual, the units are removed.
+ # Since it is the combined residual, the units are removed.
dr = self.data_error
return np.hstack([rv.value for rv in dr.values()])
@property
def unit(self):
- units = {}
- for k, v in self.residual_objs.items():
- units[k] = v.unit
- return units
+ return {k: v.unit for k, v in self.residual_objs.items()}
@property
def chi2(self):
- chi2 = 0
- for res in self.residual_objs.values():
- chi2 += res.chi2
- return chi2
+ return sum(res.chi2 for res in self.residual_objs.values())
@property
def data_error(self):
- errors = []
- for rs in self.residual_objs.values():
- errors.append((rs.residual_type, rs.get_data_error()))
+ errors = [
+ (rs.residual_type, rs.get_data_error())
+ for rs in self.residual_objs.values()
+ ]
return collections.OrderedDict(errors)
def rms_weighted(self):
- """Compute weighted RMS of the residals in time."""
+ """Compute weighted RMS of the residuals in time."""
if np.any(self._combined_data_error == 0):
raise ValueError(
diff --git a/src/pint/sampler.py b/src/pint/sampler.py
index 8a8f1348d..53d09bddf 100644
--- a/src/pint/sampler.py
+++ b/src/pint/sampler.py
@@ -94,14 +94,12 @@ def get_initial_pos(self, fitkeys, fitvals, fiterrs, errfact, **kwargs):
"""
if len(fitkeys) != len(fitvals):
raise ValueError(
- "Number of keys does ({}) not match number of values ({})!".format(
- len(fitkeys), len(fitvals)
- )
+ f"Number of keys does ({len(fitkeys)}) not match number of values ({len(fitvals)})!"
)
n_fit_params = len(fitvals)
pos = [
fitvals + fiterrs * errfact * np.random.randn(n_fit_params)
- for ii in range(self.nwalkers)
+ for _ in range(self.nwalkers)
]
# set starting params
# FIXME: what about other glitch phase parameters? This can't be right!
@@ -152,7 +150,7 @@ def get_chain(self):
"""
if self.sampler is None:
raise ValueError("MCMCSampler object has not called initialize_sampler()")
- return self.sampler.chain
+ return self.sampler.get_chain()
def chains_to_dict(self, names):
"""
@@ -160,7 +158,8 @@ def chains_to_dict(self, names):
"""
if self.sampler is None:
raise ValueError("MCMCSampler object has not called initialize_sampler()")
- chains = [self.sampler.chain[:, :, ii].T for ii in range(len(names))]
+ samples = np.transpose(self.sampler.get_chain(), (1, 0, 2))
+ chains = [samples[:, :, ii].T for ii in range(len(names))]
return dict(zip(names, chains))
def run_mcmc(self, pos, nsteps):
diff --git a/src/pint/scripts/convert_parfile.py b/src/pint/scripts/convert_parfile.py
index a393cdcdc..a0b7ec749 100644
--- a/src/pint/scripts/convert_parfile.py
+++ b/src/pint/scripts/convert_parfile.py
@@ -1,6 +1,8 @@
import argparse
import os
+from astropy import units as u
+
import pint.logging
from loguru import logger as log
@@ -8,6 +10,7 @@
from pint.models import get_model
from pint.models.parameter import _parfile_formats
+import pint.binaryconvert
__all__ = ["main"]
@@ -26,12 +29,36 @@ def main(argv=None):
choices=_parfile_formats,
default="pint",
)
+ parser.add_argument(
+ "-b",
+ "--binary",
+ help="Binary model for output",
+ choices=pint.binaryconvert.binary_types,
+ default=None,
+ )
parser.add_argument(
"-o",
"--out",
help=("Output filename [default=stdout]"),
default=None,
)
+ parser.add_argument(
+ "--nharms",
+ default=3,
+ type=int,
+ help="Number of harmonics (convert to ELL1H only)",
+ )
+ parser.add_argument(
+ "--usestigma",
+ action="store_true",
+ help="Use STIGMA instead of H4? (convert to ELL1H only)",
+ )
+ parser.add_argument(
+ "--kom",
+ type=float,
+ default=0,
+ help="KOM (longitude of ascending node) in deg (convert to DDK only)",
+ )
parser.add_argument(
"--log-level",
type=str,
@@ -57,6 +84,18 @@ def main(argv=None):
log.info(f"Reading '{args.input}'")
model = get_model(args.input)
+ if hasattr(model, "BINARY") and args.binary is not None:
+ log.info(f"Converting from {model.BINARY.value} to {args.binary}")
+ if args.binary == "ELL1H":
+ model = pint.binaryconvert.convert_binary(
+ model, args.binary, NHARMS=args.nharms, useSTIGMA=args.usestigma
+ )
+ elif args.binary == "DDK":
+ model = pint.binaryconvert.convert_binary(
+ model, args.binary, KOM=args.kom * u.deg
+ )
+ else:
+ model = pint.binaryconvert.convert_binary(model, args.binary)
output = model.as_parfile(format=args.format)
if args.out is None:
# just output to STDOUT
diff --git a/src/pint/scripts/event_optimize.py b/src/pint/scripts/event_optimize.py
index fcc2a1273..97ae53796 100755
--- a/src/pint/scripts/event_optimize.py
+++ b/src/pint/scripts/event_optimize.py
@@ -68,8 +68,7 @@ def read_gaussfitfile(gaussfitfile, proflen):
fwhms.append(float(line.split()[2]))
if not (len(phass) == len(ampls) == len(fwhms)):
log.warning(
- "Number of phases, amplitudes, and FWHMs are not the same in '%s'!"
- % gaussfitfile
+ f"Number of phases, amplitudes, and FWHMs are not the same in '{gaussfitfile}'!"
)
return 0.0
phass = np.asarray(phass)
@@ -248,6 +247,82 @@ def get_fit_keyvals(model, phs=0.0, phserr=0.1):
return fitkeys, np.asarray(fitvals), np.asarray(fiterrs)
+def run_sampler_autocorr(sampler, pos, nsteps, burnin, csteps=100, crit1=10):
+ """Runs the sampler and checks for chain convergence. Return the converged sampler and the mean autocorrelation time per 100 steps
+ Parameters
+ ----------
+ Sampler
+ The Emcee Ensemble Sampler
+ pos
+ The Initial positions of the walkers
+ nsteps : int
+ The number of integration steps
+ csteps : int
+ The interval at which the autocorrelation time is computed.
+ crit1 : int
+ The ratio of chain length to autocorrelation time to satisfy convergence
+ Returns
+ -------
+ The sampler and the mean autocorrelation times
+ Note
+ ----
+ The function checks for convergence of the chains every specified number of steps.
+ The criteria to check for convergence is:
+ 1. the chain has to be longer than the specified ratio times the estimated autocorrelation time
+ 2. the change in the estimated autocorrelation time is less than 1%
+ """
+ autocorr = []
+ old_tau = np.inf
+ converged1 = False
+ converged2 = False
+ for sample in sampler.sample(pos, iterations=nsteps, progress=True):
+ if not converged1:
+ # Checks if the iteration is past the burnin and checks for convergence at 10% tau change
+ if sampler.iteration >= burnin and sampler.iteration % csteps == 0:
+ tau = sampler.get_autocorr_time(tol=0, quiet=True)
+ if np.any(np.isnan(tau)):
+ continue
+ else:
+ x = np.mean(tau)
+ autocorr.append(x)
+ converged1 = np.all(tau * crit1 < sampler.iteration)
+ converged1 &= np.all(np.abs(old_tau - tau) / tau < 0.1)
+ # log.info("The mean estimated integrated autocorrelation step is: " + str(x))
+ old_tau = tau
+ if converged1:
+ log.info(
+ "10 % convergence reached with a mean estimated integrated step: "
+ + str(x)
+ )
+ else:
+ continue
+ else:
+ continue
+ else:
+ if not converged2:
+ # Checks for convergence at every 25 steps instead of 100 and tau change is 1%
+ if sampler.iteration % int(csteps / 4) == 0:
+ tau = sampler.get_autocorr_time(tol=0, quiet=True)
+ if np.any(np.isnan(tau)):
+ continue
+ else:
+ x = np.mean(tau)
+ autocorr.append(x)
+ converged2 = np.all(tau * crit1 < sampler.iteration)
+ converged2 &= np.all(np.abs(old_tau - tau) / tau < 0.01)
+ # log.info("The mean estimated integrated autocorrelation step is: " + str(x))
+ old_tau = tau
+ converge_step = sampler.iteration
+ else:
+ continue
+ if converged2 and (sampler.iteration - burnin) >= 1000:
+ log.info(f"Convergence reached at {converge_step}")
+ break
+ else:
+ continue
+ return autocorr
+
+
class emcee_fitter(Fitter):
def __init__(
self, toas=None, model=None, template=None, weights=None, phs=0.5, phserr=0.03
@@ -280,9 +355,7 @@ def lnprior(self, theta):
for val, key in zip(theta[:-1], self.fitkeys[:-1]):
lnsum += getattr(self.model, key).prior_pdf(val, logpdf=True)
# Add the phase term
- if theta[-1] > 1.0 or theta[-1] < 0.0:
- return -np.inf
- return lnsum
+ return -np.inf if theta[-1] > 1.0 or theta[-1] < 0.0 else lnsum
def lnposterior(self, theta):
"""
@@ -366,10 +439,7 @@ def prof_vs_weights(self, nbins=50, use_weights=False):
if nphotons <= 0:
hval = 0
else:
- if use_weights:
- hval = hmw(phss[good], weights=wgts)
- else:
- hval = hm(phss[good])
+ hval = hmw(phss[good], weights=wgts) if use_weights else hm(phss[good])
htests.append(hval)
if ii > 0 and ii % 2 == 0 and ii < 20:
r, c = ((ii - 2) // 2) // 3, ((ii - 2) // 2) % 3
@@ -389,22 +459,22 @@ def prof_vs_weights(self, nbins=50, use_weights=False):
if r == 2:
ax[r][c].set_xlabel("Phase")
f.suptitle(
- "%s: Minwgt / H-test / Approx # events" % self.model.PSR.value,
+ f"{self.model.PSR.value}: Minwgt / H-test / Approx # events",
fontweight="bold",
)
if use_weights:
- plt.savefig(ftr.model.PSR.value + "_profs_v_wgtcut.png")
+ plt.savefig(f"{ftr.model.PSR.value}_profs_v_wgtcut.png")
else:
- plt.savefig(ftr.model.PSR.value + "_profs_v_wgtcut_unweighted.png")
+ plt.savefig(f"{ftr.model.PSR.value}_profs_v_wgtcut_unweighted.png")
plt.close()
plt.plot(weights, htests, "k")
plt.xlabel("Min Weight")
plt.ylabel("H-test")
plt.title(self.model.PSR.value)
if use_weights:
- plt.savefig(ftr.model.PSR.value + "_htest_v_wgtcut.png")
+ plt.savefig(f"{ftr.model.PSR.value}_htest_v_wgtcut.png")
else:
- plt.savefig(ftr.model.PSR.value + "_htest_v_wgtcut_unweighted.png")
+ plt.savefig(f"{ftr.model.PSR.value}_htest_v_wgtcut_unweighted.png")
plt.close()
def plot_priors(self, chains, burnin, bins=100, scale=False):
@@ -551,6 +621,13 @@ def main(argv=None):
default=False,
action="store_true",
)
+ parser.add_argument(
+ "--no-autocorr",
+ help="Turn the autocorrelation check function off",
+ default=False,
+ action="store_true",
+ dest="noautocorr",
+ )
args = parser.parse_args(argv)
pint.logging.setup(
@@ -587,8 +664,8 @@ def main(argv=None):
modelin = pint.models.get_model(parfile)
# File name setup and clobber file check
- filepath = args.filepath if args.filepath else os.getcwd()
- basename = args.basename if args.basename else modelin.PSR.value
+ filepath = args.filepath or os.getcwd()
+ basename = args.basename or modelin.PSR.value
filename = os.path.join(filepath, basename)
check_file = os.path.isfile(
@@ -632,23 +709,17 @@ def main(argv=None):
except IOError:
pass
if ts is None:
- # Read event file and return list of TOA objects
- tl = fermi.load_Fermi_TOAs(
- eventfile, weightcolumn=weightcol, targetcoord=target, minweight=minWeight
+ ts = fermi.get_Fermi_TOAs(
+ eventfile,
+ weightcolumn=weightcol,
+ targetcoord=target,
+ minweight=minWeight,
+ minmjd=minMJD,
+ maxmjd=maxMJD,
+ ephem="DE421",
+ planets=False,
)
- # Limit the TOAs to ones in selected MJD range and above minWeight
- tl = [
- tl[ii]
- for ii in range(len(tl))
- if (
- tl[ii].mjd.value > minMJD
- and tl[ii].mjd.value < maxMJD
- and (weightcol is None or float(tl[ii].flags["weight"]) > minWeight)
- )
- ]
- log.info("There are %d events we will use" % len(tl))
- # Now convert to TOAs object and compute TDBs and posvels
- ts = toa.get_TOAs_list(tl, ephem="DE421", planets=False)
+ log.info("There are %d events we will use" % len(ts))
ts.filename = eventfile
# FIXME: writes to the TOA directory unconditionally
try:
@@ -832,7 +903,10 @@ def unwrapped_lnpost(theta):
pool=pool,
backend=backend,
)
- sampler.run_mcmc(pos, nsteps)
+ if args.noautocorr:
+ sampler.run_mcmc(pos, nsteps, progress=True)
+ else:
+ autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
pool.close()
pool.join()
except ImportError:
@@ -840,16 +914,22 @@ def unwrapped_lnpost(theta):
sampler = emcee.EnsembleSampler(
nwalkers, ndim, ftr.lnposterior, blobs_dtype=dtype, backend=backend
)
- sampler.run_mcmc(pos, nsteps)
+ if args.noautocorr:
+ sampler.run_mcmc(pos, nsteps, progress=True)
+ else:
+ autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
else:
sampler = emcee.EnsembleSampler(
nwalkers, ndim, ftr.lnposterior, blobs_dtype=dtype, backend=backend
)
- # The number is the number of points in the chain
- sampler.run_mcmc(pos, nsteps)
+ if args.noautocorr:
+ sampler.run_mcmc(pos, nsteps, progress=True)
+ else:
+ autocorr = run_sampler_autocorr(sampler, pos, nsteps, burnin)
def chains_to_dict(names, sampler):
- chains = [sampler.chain[:, :, ii].T for ii in range(len(names))]
+ samples = np.transpose(sampler.get_chain(), (1, 0, 2))
+ chains = [samples[:, :, ii].T for ii in range(len(names))]
return dict(zip(names, chains))
def plot_chains(chain_dict, file=False):
@@ -871,7 +951,9 @@ def plot_chains(chain_dict, file=False):
plot_chains(chains, file=filename + "_chains.png")
# Make the triangle plot.
- samples = sampler.chain[:, burnin:, :].reshape((-1, ndim))
+ samples = np.transpose(sampler.get_chain(discard=burnin), (1, 0, 2)).reshape(
+ (-1, ndim)
+ )
blobs = sampler.get_blobs()
lnprior_samps = blobs["lnprior"]
diff --git a/src/pint/scripts/event_optimize_MCMCFitter.py b/src/pint/scripts/event_optimize_MCMCFitter.py
index bde72ed93..bbab8ccee 100755
--- a/src/pint/scripts/event_optimize_MCMCFitter.py
+++ b/src/pint/scripts/event_optimize_MCMCFitter.py
@@ -358,7 +358,10 @@ def plot_chains(chain_dict, file=False):
plot_chains(chains, file=ftr.model.PSR.value + "_chains.png")
# Make the triangle plot.
- samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
+ # samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
+ samples = np.transpose(
+ sampler.sampler.get_chain(discard=burnin), (1, 0, 2)
+ ).reshape((-1, ftr.n_fit_params))
try:
import corner
diff --git a/src/pint/scripts/event_optimize_multiple.py b/src/pint/scripts/event_optimize_multiple.py
index e5aab3058..41316a000 100755
--- a/src/pint/scripts/event_optimize_multiple.py
+++ b/src/pint/scripts/event_optimize_multiple.py
@@ -1,6 +1,8 @@
#!/usr/bin/env python -W ignore::FutureWarning -W ignore::UserWarning -W ignore::DeprecationWarning
import argparse
+import contextlib
import sys
+import pickle
import astropy.units as u
import matplotlib.pyplot as plt
@@ -20,6 +22,8 @@
import pint.fermi_toas as fermi
import pint.models
import pint.toa as toa
+from pint.templates import lctemplate, lcfitters
+from pint.residuals import Residuals
from pint.mcmc_fitter import CompositeMCMCFitter
from pint.observatory.satellite_obs import get_satellite_observatory
from pint.sampler import EmceeSampler
@@ -37,9 +41,7 @@
def get_toas(evtfile, flags, tcoords=None, minweight=0, minMJD=0, maxMJD=100000):
if evtfile[:-3] == "tim":
- usepickle = False
- if "usepickle" in flags:
- usepickle = flags["usepickle"]
+ usepickle = flags["usepickle"] if "usepickle" in flags else False
ts = toa.get_TOAs(evtfile, usepickle=usepickle)
# Prune out of range MJDs
mask = np.logical_or(
@@ -49,14 +51,9 @@ def get_toas(evtfile, flags, tcoords=None, minweight=0, minMJD=0, maxMJD=100000)
ts.table = ts.table.group_by("obs")
else:
if "usepickle" in flags and flags["usepickle"]:
- try:
- picklefile = toa._check_pickle(evtfile)
- if not picklefile:
- picklefile = evtfile
- ts = toa.TOAs(picklefile)
- return ts
- except:
- pass
+ with contextlib.suppress(Exception):
+ picklefile = toa._check_pickle(evtfile) or evtfile
+ return toa.TOAs(picklefile)
weightcol = flags["weightcol"] if "weightcol" in flags else None
target = tcoords if weightcol == "CALC" else None
tl = fermi.load_Fermi_TOAs(
@@ -90,29 +87,26 @@ def load_eventfiles(infile, tcoords=None, minweight=0, minMJD=0, maxMJD=100000):
"""
lines = open(infile, "r").read().split("\n")
- eventinfo = {}
- eventinfo["toas"] = []
- eventinfo["lnlikes"] = []
- eventinfo["templates"] = []
- eventinfo["weightcol"] = []
- eventinfo["setweights"] = []
-
+ eventinfo = {
+ "toas": [],
+ "lnlikes": [],
+ "templates": [],
+ "weightcol": [],
+ "setweights": [],
+ }
for line in lines:
- log.info("%s" % line)
+ log.info(f"{line}")
if len(line) == 0:
continue
try:
words = line.split()
+ flags = {}
if len(words) > 3:
kvs = words[3:]
- flags = {}
for i in range(0, len(flags), 2):
k, v = kvs[i].lstrip("-"), kvs[i + 1]
flags[k] = v
- else:
- flags = {}
-
ts = get_toas(
words[0],
flags,
@@ -134,8 +128,8 @@ def load_eventfiles(infile, tcoords=None, minweight=0, minMJD=0, maxMJD=100000):
else:
eventinfo["weightcol"].append(None)
except Exception as e:
- log.error("%s" % str(e))
- log.error("Could not load %s" % line)
+ log.error(f"{str(e)}")
+ log.error(f"Could not load {line}")
return eventinfo
@@ -301,9 +295,7 @@ def main(argv=None):
try:
lnlike_funcs[i] = funcs[eventinfo["lnlikes"][i]]
except:
- raise ValueError(
- "%s is not a recognized function" % eventinfo["lnlikes"][i]
- )
+ raise ValueError(f'{eventinfo["lnlikes"][i]} is not a recognized function')
# Load in weights
ts = eventinfo["toas"][i]
@@ -339,13 +331,13 @@ def main(argv=None):
if tname[-6:] == "pickle" or tname == "analytic":
# Analytic template
try:
- gtemplate = cPickle.load(file(tname))
- except:
+ gtemplate = pickle.load(file(tname))
+ except Exception:
phases = (modelin.phase(ts)[1].value).astype(np.float64) % 1
gtemplate = lctemplate.get_gauss2()
lcf = lcfitters.LCFitter(gtemplate, phases, weights=wlist[i])
lcf.fit(unbinned=False)
- cPickle.dump(
+ pickle.dump(
gtemplate,
file("%s_template%d.pickle" % (jname, i), "wb"),
protocol=2,
@@ -398,7 +390,7 @@ def main(argv=None):
ftr.prof_vs_weights(use_weights=False)
sys.exit()
- ftr.phaseogram(plotfile=ftr.model.PSR.value + "_pre.png")
+ ftr.phaseogram(plotfile=f"{ftr.model.PSR.value}_pre.png")
like_start = ftr.lnlikelihood(ftr, ftr.get_parameters())
log.info("Starting Pulse Likelihood:\t%f" % like_start)
@@ -407,7 +399,7 @@ def main(argv=None):
if args.samples is None:
pos = None
else:
- chains = cPickle.load(file(args.samples))
+ chains = pickle.load(file(args.samples))
chains = np.reshape(chains, [nwalkers, -1, ndim])
pos = chains[:, -1, :]
@@ -431,11 +423,14 @@ def plot_chains(chain_dict, file=False):
plt.close()
chains = sampler.chains_to_dict(ftr.fitkeys)
- plot_chains(chains, file=ftr.model.PSR.value + "_chains.png")
+ plot_chains(chains, file=f"{ftr.model.PSR.value}_chains.png")
# Make the triangle plot.
- samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
- try:
+ # samples = sampler.sampler.chain[:, burnin:, :].reshape((-1, ftr.n_fit_params))
+ samples = np.transpose(
+ sampler.sampler.get_chain(discard=burnin), (1, 0, 2)
+ ).reshape((-1, ftr.n_fit_params))
+ with contextlib.suppress(ImportError):
import corner
fig = corner.corner(
@@ -445,23 +440,17 @@ def plot_chains(chain_dict, file=False):
truths=ftr.maxpost_fitvals,
plot_contours=True,
)
- fig.savefig(ftr.model.PSR.value + "_triangle.png")
+ fig.savefig(f"{ftr.model.PSR.value}_triangle.png")
plt.close()
- except ImportError:
- pass
-
# Make a phaseogram with the 50th percentile values
# ftr.set_params(dict(zip(ftr.fitkeys, np.percentile(samples, 50, axis=0))))
# Make a phaseogram with the best MCMC result
ftr.set_parameters(ftr.maxpost_fitvals)
- ftr.phaseogram(plotfile=ftr.model.PSR.value + "_post.png")
+ ftr.phaseogram(plotfile=f"{ftr.model.PSR.value}_post.png")
plt.close()
- # Write out the par file for the best MCMC parameter est
- f = open(ftr.model.PSR.value + "_post.par", "w")
- f.write(ftr.model.as_parfile())
- f.close()
-
+ with open(f"{ftr.model.PSR.value}_post.par", "w") as f:
+ f.write(ftr.model.as_parfile())
# Print the best MCMC values and ranges
ranges = map(
lambda v: (v[1], v[2] - v[1], v[1] - v[0]),
@@ -471,17 +460,12 @@ def plot_chains(chain_dict, file=False):
for name, vals in zip(ftr.fitkeys, ranges):
log.info("%8s:" % name + "%25.15g (+ %12.5g / - %12.5g)" % vals)
- # Put the same stuff in a file
- f = open(ftr.model.PSR.value + "_results.txt", "w")
-
- f.write("Post-MCMC values (50th percentile +/- (16th/84th percentile):\n")
- for name, vals in zip(ftr.fitkeys, ranges):
- f.write("%8s:" % name + " %25.15g (+ %12.5g / - %12.5g)\n" % vals)
-
- f.write("\nMaximum likelihood par file:\n")
- f.write(ftr.model.as_parfile())
- f.close()
-
- import cPickle
+ with open(f"{ftr.model.PSR.value}_results.txt", "w") as f:
+ f.write("Post-MCMC values (50th percentile +/- (16th/84th percentile):\n")
+ for name, vals in zip(ftr.fitkeys, ranges):
+ f.write("%8s:" % name + " %25.15g (+ %12.5g / - %12.5g)\n" % vals)
- cPickle.dump(samples, open(ftr.model.PSR.value + "_samples.pickle", "wb"))
+ f.write("\nMaximum likelihood par file:\n")
+ f.write(ftr.model.as_parfile())
+ with open(f"{ftr.model.PSR.value}_samples.pickle", "wb") as smppkl:
+ pickle.dump(samples, smppkl)
diff --git a/src/pint/scripts/fermiphase.py b/src/pint/scripts/fermiphase.py
index 6b07a7284..b1427e61f 100755
--- a/src/pint/scripts/fermiphase.py
+++ b/src/pint/scripts/fermiphase.py
@@ -167,11 +167,11 @@ def main(argv=None):
hdulist[1] = bt
if args.outfile is None:
# Overwrite the existing file
- log.info("Overwriting existing FITS file " + args.eventfile)
+ log.info(f"Overwriting existing FITS file {args.eventfile}")
hdulist.flush(verbose=True, output_verify="warn")
else:
# Write to new output file
- log.info("Writing output FITS file " + args.outfile)
+ log.info(f"Writing output FITS file {args.outfile}")
hdulist.writeto(
args.outfile, overwrite=True, checksum=True, output_verify="warn"
)
diff --git a/src/pint/scripts/photonphase.py b/src/pint/scripts/photonphase.py
index 6a2f50572..a0f5c5159 100755
--- a/src/pint/scripts/photonphase.py
+++ b/src/pint/scripts/photonphase.py
@@ -12,7 +12,7 @@
import pint.models
import pint.residuals
import pint.toa as toa
-from pint.event_toas import load_event_TOAs
+from pint.event_toas import get_event_TOAs
from pint.eventstats import h2sig, hm
from pint.fits_utils import read_fits_event_mjds
from pint.observatory.satellite_obs import get_satellite_observatory
@@ -152,11 +152,34 @@ def main(argv=None):
"The orbit file is not recognized. It is likely that this mission is not supported. "
"Please barycenter the event file using the official mission tools before processing with PINT"
)
+ # Read in model
+ modelin = pint.models.get_model(args.parfile)
+ use_planets = False
+ if "PLANET_SHAPIRO" in modelin.params:
+ if modelin.PLANET_SHAPIRO.value:
+ use_planets = True
+ if "AbsPhase" not in modelin.components:
+ log.error(
+ "TimingModel does not include AbsPhase component, which is required "
+ "for computing phases. Make sure you have TZR* parameters in your par file!"
+ )
+ raise ValueError("Model missing AbsPhase component.")
+
# Read event file and return list of TOA objects, if not using polycos
if args.polycos == False:
try:
- tl = load_event_TOAs(
- args.eventfile, telescope, minmjd=minmjd, maxmjd=maxmjd
+ # tl = load_event_TOAs(
+ # args.eventfile, telescope, minmjd=minmjd, maxmjd=maxmjd
+ # )
+ ts = get_event_TOAs(
+ args.eventfile,
+ telescope,
+ minmjd=minmjd,
+ maxmjd=maxmjd,
+ ephem=args.ephem,
+ include_bipm=args.use_bipm,
+ include_gps=args.use_gps,
+ planets=use_planets,
)
except KeyError:
log.error(
@@ -165,23 +188,10 @@ def main(argv=None):
sys.exit(1)
# Now convert to TOAs object and compute TDBs and posvels
- if len(tl) == 0:
+ if len(ts) == 0:
log.error("No TOAs, exiting!")
sys.exit(0)
- # Read in model
- modelin = pint.models.get_model(args.parfile)
- use_planets = False
- if "PLANET_SHAPIRO" in modelin.params:
- if modelin.PLANET_SHAPIRO.value:
- use_planets = True
- if "AbsPhase" not in modelin.components:
- log.error(
- "TimingModel does not include AbsPhase component, which is required "
- "for computing phases. Make sure you have TZR* parameters in your par file!"
- )
- raise ValueError("Model missing AbsPhase component.")
-
if args.addorbphase and (not hasattr(modelin, "binary_model_name")):
log.error(
"TimingModel does not include a binary model, which is required for "
@@ -230,14 +240,6 @@ def main(argv=None):
h = float(hm(phases))
print("Htest : {0:.2f} ({1:.2f} sigma)".format(h, h2sig(h)))
else: # Normal mode, not polycos
- ts = toa.get_TOAs_list(
- tl,
- ephem=args.ephem,
- include_bipm=args.use_bipm,
- include_gps=args.use_gps,
- planets=use_planets,
- tdb_method=args.tdbmethod,
- )
ts.filename = args.eventfile
# if args.fix:
# ts.adjust_TOAs(TimeDelta(np.ones(len(ts.table))*-1.0*u.s,scale='tt'))
diff --git a/src/pint/scripts/pintempo.py b/src/pint/scripts/pintempo.py
index 67ae7bfc1..7d21c6323 100755
--- a/src/pint/scripts/pintempo.py
+++ b/src/pint/scripts/pintempo.py
@@ -119,7 +119,7 @@ def main(argv=None):
xt = t.get_mjds()
ax.errorbar(xt, prefit_resids.to(u.us), t.get_errors().to(u.us), fmt="o")
ax.errorbar(xt, f.resids.time_resids.to(u.us), t.get_errors().to(u.us), fmt="x")
- ax.set_title("%s Timing Residuals" % m.PSR.value)
+ ax.set_title(f"{m.PSR.value} Timing Residuals")
ax.set_xlabel("MJD")
ax.set_ylabel("Residual (us)")
ax.grid()
diff --git a/src/pint/scripts/tcb2tdb.py b/src/pint/scripts/tcb2tdb.py
new file mode 100644
index 000000000..505d74d4e
--- /dev/null
+++ b/src/pint/scripts/tcb2tdb.py
@@ -0,0 +1,52 @@
+"""PINT-based tool for converting TCB par files to TDB."""
+
+import argparse
+
+from loguru import logger as log
+
+import pint.logging
+from pint.models.model_builder import ModelBuilder
+
+pint.logging.setup(level="INFO")
+
+__all__ = ["main"]
+
+
+def main(argv=None):
+ parser = argparse.ArgumentParser(
+ description="""`tcb2tdb` converts TCB par files to TDB.
+ Please note that this conversion is not exact and the timing model
+ should be re-fit to the TOAs.
+
+ The following parameters are converted to TDB:
+ 1. Spin frequency, its derivatives and spin epoch
+ 2. Sky coordinates, proper motion and the position epoch
+ 3. DM, DM derivatives and DM epoch
+ 4. Keplerian binary parameters and FB1
+
+ The following parameters are NOT converted although they are
+ in fact affected by the TCB to TDB conversion:
+ 1. Parallax
+ 2. TZRMJD and TZRFRQ
+ 3. DMX parameters
+ 4. Solar wind parameters
+ 5. Binary post-Keplerian parameters including Shapiro delay
+ parameters (except FB1)
+ 6. Jumps and DM Jumps
+ 7. FD parameters
+ 8. EQUADs
+ 9. Red noise parameters including FITWAVES, powerlaw red noise and
+ powerlaw DM noise parameters
+ """,
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument("input_par", help="Input par file name (TCB)")
+ parser.add_argument("output_par", help="Output par file name (TDB)")
+
+ args = parser.parse_args(argv)
+
+ mb = ModelBuilder()
+ model = mb(args.input_par, allow_tcb=True)
+ model.write_parfile(args.output_par)
+
+ log.info(f"Output written to {args.output_par}.")
diff --git a/src/pint/simulation.py b/src/pint/simulation.py
index 222ab1cf3..45b9eb5c8 100644
--- a/src/pint/simulation.py
+++ b/src/pint/simulation.py
@@ -106,9 +106,8 @@ def get_fake_toa_clock_versions(model, include_bipm=False, include_gps=True):
if len(clk) == 2:
ctype, cvers = clk
if ctype == "TT" and cvers.startswith("BIPM"):
- if bipm_version is None:
- bipm_version = cvers
- log.info(f"Using CLOCK = {bipm_version} from the given model")
+ bipm_version = cvers
+ log.info(f"Using CLOCK = {bipm_version} from the given model")
else:
log.warning(
f'CLOCK = {model["CLOCK"].value} is not implemented. '
@@ -295,7 +294,7 @@ def make_fake_toas_uniform(
include_bipm=clk_version["include_bipm"],
bipm_version=clk_version["bipm_version"],
include_gps=clk_version["include_gps"],
- planets=model["PLANET_SHAPIRO"].value,
+ planets=model["PLANET_SHAPIRO"].value if "PLANET_SHAPIRO" in model else False,
)
ts.table["error"] = error
@@ -427,7 +426,7 @@ def make_fake_toas_fromtim(timfile, model, add_noise=False, name="fake"):
--------
:func:`make_fake_toas`
"""
- input_ts = pint.toa.get_TOAs(timfile)
+ input_ts = pint.toa.get_TOAs(timfile, planets=model.PLANET_SHAPIRO.value)
if input_ts.is_wideband():
dm_errors = input_ts.get_dm_errors()
diff --git a/src/pint/solar_system_ephemerides.py b/src/pint/solar_system_ephemerides.py
index a3ef8cead..76f4a6644 100644
--- a/src/pint/solar_system_ephemerides.py
+++ b/src/pint/solar_system_ephemerides.py
@@ -1,8 +1,10 @@
"""Solar system ephemeris downloading and setting support."""
+
import os
import astropy.coordinates
import astropy.units as u
+import contextlib
import numpy as np
from astropy.utils.data import download_file
from loguru import logger as log
@@ -44,7 +46,7 @@ def _load_kernel_link(ephem, link=None):
if link == "":
raise ValueError("Empty string is not a valid URL")
- mirrors = [m + f"{ephem}.bsp" for m in ephemeris_mirrors]
+ mirrors = [f"{m}{ephem}.bsp" for m in ephemeris_mirrors]
if link is not None:
mirrors = [link] + mirrors
astropy.coordinates.solar_system_ephemeris.set(
@@ -54,26 +56,18 @@ def _load_kernel_link(ephem, link=None):
def _load_kernel_local(ephem, path):
- ephem_bsp = "%s.bsp" % ephem
- if os.path.isdir(path):
- custom_path = os.path.join(path, ephem_bsp)
- else:
- custom_path = path
+ ephem_bsp = f"{ephem}.bsp"
+ custom_path = os.path.join(path, ephem_bsp) if os.path.isdir(path) else path
search_list = [custom_path]
- try:
+ with contextlib.suppress(FileNotFoundError):
search_list.append(pint.config.runtimefile(ephem_bsp))
- except FileNotFoundError:
- # If not found in runtimefile path, just continue. Error will be raised later if also not in "path"
- pass
for p in search_list:
if os.path.exists(p):
# .set() can accept a path to an ephemeris
astropy.coordinates.solar_system_ephemeris.set(ephem)
- log.info("Set solar system ephemeris to local file:\n\t{}".format(ephem))
+ log.info(f"Set solar system ephemeris to local file:\n\t{ephem}")
return
- raise FileNotFoundError(
- "ephemeris file {} not found in any of {}".format(ephem, search_list)
- )
+ raise FileNotFoundError(f"ephemeris file {ephem} not found in any of {search_list}")
def load_kernel(ephem, path=None, link=None):
@@ -125,15 +119,13 @@ def load_kernel(ephem, path=None, link=None):
)
# Links are just suggestions, try just plain loading
# Astropy may download something here, not from nanograv
- try:
+ # Exception here means it wasn't a standard astropy ephemeris
+ # or astropy can't access it (because astropy doesn't know about
+ # the nanograv mirrors)
+ with contextlib.suppress(ValueError, OSError):
astropy.coordinates.solar_system_ephemeris.set(ephem)
log.info(f"Set solar system ephemeris to {ephem} through astropy")
return
- except (ValueError, OSError):
- # Just means it wasn't a standard astropy ephemeris
- # or astropy can't access it (because astropy doesn't know about
- # the nanograv mirrors)
- pass
# If this raises an exception our last hope is gone so let it propagate
_load_kernel_link(ephem, link=link)
@@ -198,8 +190,7 @@ def objPosVel(obj1, obj2, t, ephem, path=None, link=None):
J2000 cartesian coordinate.
"""
if obj1.lower() == "ssb" and obj2.lower() != "ssb":
- obj2pv = objPosVel_wrt_SSB(obj2, t, ephem, path=path, link=link)
- return obj2pv
+ return objPosVel_wrt_SSB(obj2, t, ephem, path=path, link=link)
elif obj2.lower() == "ssb" and obj1.lower() != "ssb":
obj1pv = objPosVel_wrt_SSB(obj1, t, ephem, path=path, link=link)
return -obj1pv
diff --git a/src/pint/templates/lcenorm.py b/src/pint/templates/lcenorm.py
index 6735f4e06..d85a1fe67 100644
--- a/src/pint/templates/lcenorm.py
+++ b/src/pint/templates/lcenorm.py
@@ -53,8 +53,8 @@ def get_free_mask(self):
def get_bounds(self, free=True):
PI2 = np.pi * 0.5
- b1 = np.asarray([[0, PI2] for i in range(self.dim)])
- b2 = np.asarray([[-PI2, PI2] for i in range(self.dim)])
+ b1 = np.asarray([[0, PI2] for _ in range(self.dim)])
+ b2 = np.asarray([[-PI2, PI2] for _ in range(self.dim)])
if free:
return np.concatenate((b1[self.free], b2[self.slope_free]))
else:
@@ -111,6 +111,4 @@ def gradient(self, log10_ens, free=True):
rvals = np.empty((self.dim, 2 * self.dim, p.shape[1]))
rvals[:, : self.dim] = g0
rvals[:, self.dim :] = g0 * e
- if free:
- return rvals[:, np.append(self.free, self.slope_free)]
- return rvals
+ return rvals[:, np.append(self.free, self.slope_free)] if free else rvals
diff --git a/src/pint/templates/lceprimitives.py b/src/pint/templates/lceprimitives.py
index 677c5f0f6..ad0a1a0d0 100644
--- a/src/pint/templates/lceprimitives.py
+++ b/src/pint/templates/lceprimitives.py
@@ -14,7 +14,7 @@ def edep_gradient(self, grad_func, phases, log10_ens=3, free=False):
the difference in (log) energy.
However, there is one complication. Because of the bounds enforced
- by "_make_p", the gradient for the slope parameters vanishese at
+ by "_make_p", the gradient for the slope parameters vanishes at
some energies when the bound has saturated. These entries should be
zeroed.
"""
@@ -37,9 +37,7 @@ def edep_gradient(self, grad_func, phases, log10_ens=3, free=False):
hi_mask = p[i] >= bounds[i][1]
t[n + i, lo_mask | hi_mask] = 0
t[i, lo_mask | hi_mask] = 0
- if free:
- return t[np.append(self.free, self.slope_free)]
- return t
+ return t[np.append(self.free, self.slope_free)] if free else t
class LCEPrimitive(LCPrimitive):
@@ -48,11 +46,11 @@ def is_energy_dependent(self):
# TODO -- this is so awkward, fix it?
def parse_kwargs(self, kwargs):
- # acceptable keyword arguments, can be overriden by children
+ # acceptable keyword arguments, can be overridden by children
recognized_kwargs = ["p", "free", "slope", "slope_free"]
for key in kwargs.keys():
if key not in recognized_kwargs:
- raise ValueError("kwarg %s not recognized" % key)
+ raise ValueError(f"kwarg {key} not recognized")
self.__dict__.update(kwargs)
def _einit(self):
@@ -267,29 +265,31 @@ def init(self):
def base_int(self, x1, x2, log10_ens, index=0):
# TODO -- I haven't checked this code
raise NotImplementedError()
- e, gamma1, gamma2, x0 = self._make_p(log10_ens)
- # the only case where g1 and g2 can be different is if we're on the
- # 0th wrap, i.e. index=0; this also includes the case when we want
- # to use base_int to do a "full" integral
- g1 = np.where((x1 + index) < x0, gamma1, gamma2)
- g2 = np.where((x2 + index) >= x0, gamma2, gamma1)
- z1 = (x1 + index - x0) / g1
- z2 = (x2 + index - x0) / g2
- k = 2.0 / (gamma1 + gamma2) / PI
- return k * (g2 * np.arctan(z2) - g1 * np.arctan(z1))
+ # abhisrkckl: commented out unreachable code
+ # e, gamma1, gamma2, x0 = self._make_p(log10_ens)
+ # # the only case where g1 and g2 can be different is if we're on the
+ # # 0th wrap, i.e. index=0; this also includes the case when we want
+ # # to use base_int to do a "full" integral
+ # g1 = np.where((x1 + index) < x0, gamma1, gamma2)
+ # g2 = np.where((x2 + index) >= x0, gamma2, gamma1)
+ # z1 = (x1 + index - x0) / g1
+ # z2 = (x2 + index - x0) / g2
+ # k = 2.0 / (gamma1 + gamma2) / PI
+ # return k * (g2 * np.arctan(z2) - g1 * np.arctan(z1))
def random(self, log10_ens):
"""Use multinomial technique to return random photons from
both components."""
# TODO -- I haven't checked this code
raise NotImplementedError()
- if not isvector(log10_ens):
- n = log10_ens
- log10_ens = 3
- else:
- n = len(log10_ens)
- e, gamma1, gamma2, x0 = self._make_p(log10_ens) # only change
- return two_comp_mc(n, gamma1, gamma2, x0, cauchy.rvs)
+ # abhisrkckl: commented out unreachable code
+ # if not isvector(log10_ens):
+ # n = log10_ens
+ # log10_ens = 3
+ # else:
+ # n = len(log10_ens)
+ # e, gamma1, gamma2, x0 = self._make_p(log10_ens) # only change
+ # return two_comp_mc(n, gamma1, gamma2, x0, cauchy.rvs)
class LCEGaussian2(LCEWrappedFunction, LCGaussian2):
@@ -308,27 +308,29 @@ def init(self):
def base_int(self, x1, x2, log10_ens, index=0):
# TODO -- I haven't checked this code
raise NotImplementedError()
- e, width1, width2, x0 = self._make_p(log10_ens)
- w1 = np.where((x1 + index) < x0, width1, width2)
- w2 = np.where((x2 + index) >= x0, width2, width1)
- z1 = (x1 + index - x0) / w1
- z2 = (x2 + index - x0) / w2
- k1 = 2 * w1 / (width1 + width2)
- k2 = 2 * w2 / (width1 + width2)
- return 0.5 * (k2 * erf(z2 / ROOT2) - k1 * erf(z1 / ROOT2))
+ # abhisrkckl: commented out unreachable code
+ # e, width1, width2, x0 = self._make_p(log10_ens)
+ # w1 = np.where((x1 + index) < x0, width1, width2)
+ # w2 = np.where((x2 + index) >= x0, width2, width1)
+ # z1 = (x1 + index - x0) / w1
+ # z2 = (x2 + index - x0) / w2
+ # k1 = 2 * w1 / (width1 + width2)
+ # k2 = 2 * w2 / (width1 + width2)
+ # return 0.5 * (k2 * erf(z2 / ROOT2) - k1 * erf(z1 / ROOT2))
def random(self, log10_ens):
"""Use multinomial technique to return random photons from
both components."""
# TODO -- I haven't checked this code
raise NotImplementedError()
- if not isvector(log10_ens):
- n = log10_ens
- log10_ens = 3
- else:
- n = len(log10_ens)
- e, width1, width2, x0 = self.p
- return two_comp_mc(n, width1, width2, x0, norm.rvs)
+ # abhisrkckl: commented out unreachable code
+ # if not isvector(log10_ens):
+ # n = log10_ens
+ # log10_ens = 3
+ # else:
+ # n = len(log10_ens)
+ # e, width1, width2, x0 = self.p
+ # return two_comp_mc(n, width1, width2, x0, norm.rvs)
class LCEVonMises(LCEPrimitive, LCVonMises):
diff --git a/src/pint/templates/lcfitters.py b/src/pint/templates/lcfitters.py
index 0ae0b4640..ca5066f1c 100644
--- a/src/pint/templates/lcfitters.py
+++ b/src/pint/templates/lcfitters.py
@@ -4,7 +4,7 @@
a mixture model.
LCPrimitives are combined to form a light curve (LCTemplate).
-LCFitter then performs a maximum likielihood fit to determine the
+LCFitter then performs a maximum likelihood fit to determine the
light curve parameters.
LCFitter also allows fits to subsets of the phases for TOA calculation.
@@ -239,8 +239,7 @@ def chi(self, p, *args):
if not self.template.shift_mode and np.any(p < 0):
return 2e100 * np.ones_like(x) / len(x)
args[0].set_parameters(p)
- chi = (bg + (1 - bg) * self.template(x) - y) / yerr
- return chi
+ return (bg + (1 - bg) * self.template(x) - y) / yerr
def quick_fit(self):
t = self.template
@@ -262,9 +261,8 @@ def _fix_state(self, restore_state=None):
old_state.append(p.free[i])
if restore_state is not None:
p.free[i] = restore_state[counter]
- else:
- if i < (len(p.p) - 1):
- p.free[i] = False
+ elif i < (len(p.p) - 1):
+ p.free[i] = False
counter += 1
return old_state
@@ -341,25 +339,29 @@ def logl(phase):
else:
f = self.fit_fmin(fit_func, ftol=ftol)
if (ll0 > self.ll) or (ll0 == 2e20) or (np.isnan(ll0)):
- if unbinned_refit and np.isnan(ll0) and (not unbinned):
- if (self.binned_bins * 2) < 400:
- print(
- "Did not converge using %d bins... retrying with %d bins..."
- % (self.binned_bins, self.binned_bins * 2)
- )
- self.template.set_parameters(p0)
- self.ll = ll0
- self.fitvals = p0
- self.binned_bins *= 2
- self._hist_setup()
- return self.fit(
- quick_fit_first=quick_fit_first,
- unbinned=unbinned,
- use_gradient=use_gradient,
- positions_first=positions_first,
- estimate_errors=estimate_errors,
- prior=prior,
- )
+ if (
+ unbinned_refit
+ and np.isnan(ll0)
+ and (not unbinned)
+ and (self.binned_bins * 2) < 400
+ ):
+ print(
+ "Did not converge using %d bins... retrying with %d bins..."
+ % (self.binned_bins, self.binned_bins * 2)
+ )
+ self.template.set_parameters(p0)
+ self.ll = ll0
+ self.fitvals = p0
+ self.binned_bins *= 2
+ self._hist_setup()
+ return self.fit(
+ quick_fit_first=quick_fit_first,
+ unbinned=unbinned,
+ use_gradient=use_gradient,
+ positions_first=positions_first,
+ estimate_errors=estimate_errors,
+ prior=prior,
+ )
self.bad_p = self.template.get_parameters().copy()
self.bad_ll = self.ll
print("Failed likelihood fit -- resetting parameters.")
@@ -373,14 +375,12 @@ def logl(phase):
self.ll = ll0
self.fitvals = p0
return False
- if estimate_errors:
- if not self.hess_errors(use_gradient=use_gradient):
- # try:
- if try_bootstrap:
- self.bootstrap_errors(set_errors=True)
- # except ValueError:
- # print('Warning, could not estimate errors.')
- # self.template.set_errors(np.zeros_like(p0))
+ if (
+ estimate_errors
+ and not self.hess_errors(use_gradient=use_gradient)
+ and try_bootstrap
+ ):
+ self.bootstrap_errors(set_errors=True)
if not quiet:
print("Improved log likelihood by %.2f" % (self.ll - ll0))
return True
@@ -470,7 +470,7 @@ def fit_fmin(self, fit_func, ftol=1e-5):
def fit_cg(self):
from scipy.optimize import fmin_cg
- fit = fmin_cg(
+ return fmin_cg(
self.loglikelihood,
self.template.get_parameters(),
fprime=self.gradient,
@@ -478,7 +478,6 @@ def fit_cg(self):
full_output=1,
disp=1,
)
- return fit
def fit_bfgs(self):
from scipy.optimize import fmin_bfgs
@@ -523,10 +522,9 @@ def fit_l_bfgs_b(self):
x0 = self.template.get_parameters()
bounds = self.template.get_bounds()
- fit = fmin_l_bfgs_b(
+ return fmin_l_bfgs_b(
self.loglikelihood, x0, fprime=self.gradient, bounds=bounds, factr=1e-5
)
- return fit
def hess_errors(self, use_gradient=True):
"""Set errors from hessian. Fit should be called first..."""
@@ -750,9 +748,7 @@ def plot_ebands(
axzoom.tick_params(
labelleft=False, labelright=True, labelbottom=i == (nband - 1)
)
- if i < (nband - 1):
- pass
- else:
+ if i >= nband - 1:
axzoom.set_xticks([0.2, 0.4, 0.6, 0.8, 1.0])
axzoom.set_xlabel("")
axzoom.set_ylabel("")
@@ -772,7 +768,7 @@ def plot_ebands(
fig.supxlabel("Phase")
except AttributeError:
axes[-1].set_xlabel("Phase")
- if len(axzooms) > 0:
+ if axzooms:
axzooms[-1].set_xlabel("Phase")
for ax in axes:
ax.set_ylabel("Normalized Profile")
@@ -862,10 +858,7 @@ def bic(self, template=None):
else:
template = self.template
nump = len(self.template.get_parameters())
- if self.weights is None:
- n = len(self.phases)
- else:
- n = self.weights.sum()
+ n = len(self.phases) if self.weights is None else self.weights.sum()
ts = nump * np.log(n) + 2 * self()
self.template = template
return ts
@@ -875,11 +868,7 @@ def hessian(m, mf, *args, **kwargs):
"""Calculate the Hessian; mf is the minimizing function, m is the model,args additional arguments for mf."""
p = m.get_parameters().copy()
p0 = p.copy() # sacrosanct copy
- if "delt" in kwargs.keys():
- delta = kwargs["delt"]
- else:
- delta = [0.01] * len(p)
-
+ delta = kwargs.get("delt", [0.01] * len(p))
hessian = np.zeros([len(p), len(p)])
for i in range(len(p)):
delt = delta[i]
@@ -1080,9 +1069,7 @@ def f(x, i):
p0[i] = par[i] + x
delta_ll = logl(p0) - ll0 - 0.5
p0[i] = par[i]
- if abs(delta_ll) < 0.05:
- return 0
- return delta_ll
+ return 0 if abs(delta_ll) < 0.05 else delta_ll
for i in range(len(par)):
if f(maxstep, i) <= 0:
diff --git a/src/pint/templates/lcnorm.py b/src/pint/templates/lcnorm.py
index 4f3de386a..dcab992aa 100644
--- a/src/pint/templates/lcnorm.py
+++ b/src/pint/templates/lcnorm.py
@@ -154,9 +154,7 @@ def set_parameters(self, p, free=True):
self.p[:] = p
def get_parameters(self, free=True):
- if free:
- return self.p[self.free]
- return self.p
+ return self.p[self.free] if free else self.p
def get_parameter_names(self, free=True):
return [p for (p, b) in zip(self.pnames, self.free) if b]
@@ -182,12 +180,11 @@ def get_bounds(self, free=True):
"""Angles are always [0,pi/2)."""
PI2 = np.pi * 0.5
if free:
- return [[0, PI2] for ix, x in enumerate(self.free) if x]
- return [[0, PI2] for ix, x in enumerate(self.free)]
+ return [[0, PI2] for x in self.free if x]
+ return [[0, PI2] for _ in self.free]
def sanity_checks(self, eps=1e-6):
- t1 = np.abs(self().sum() - np.sin(self.p[0]) ** 2) < eps
- return t1
+ return np.abs(self().sum() - np.sin(self.p[0]) ** 2) < eps
def __call__(self, log10_ens=3):
"""Return the squared value of the Cartesian coordinates.
@@ -252,16 +249,11 @@ def _eindep_gradient(self, log10_ens=3, free=False):
for j in range(self.dim):
if j > i + 1:
break
- if j <= i:
- # these will always be sin^2 terms
- m[i, j] = n[i] * sp[j]
- else:
- # last term is cosine for all but last norm, but we won't
- # get to it here because j==i is the last term then
- m[i, j] = n[i] * cp[j]
- if free:
- return m[:, self.free]
- return m
+ # these will always be sin^2 terms if j<=i
+ # else, the last term is cosine for all but last norm, but we won't
+ # get to it here because j==i is the last term then
+ m[i, j] = n[i] * sp[j] if j <= i else n[i] * cp[j]
+ return m[:, self.free] if free else m
def gradient(self, log10_ens=3, free=False):
"""Return a matrix giving the value of the partial derivative
@@ -317,26 +309,20 @@ def hessian(self, log10_ens=3, free=False):
if k > i + 1:
break
if (j <= i) and (k <= i):
- if j != k:
- # two separate sines replacing sin^2
- m[i, j, k] = n[i] * sp[j] * sp[k]
- else:
- # diff same sine twice, getting a 2*cos
- m[i, j, k] = n[i] * 2 * c2p[j] / np.sin(p[j]) ** 2
+ m[i, j, k] = (
+ n[i] * sp[j] * sp[k]
+ if j != k
+ else n[i] * 2 * c2p[j] / np.sin(p[j]) ** 2
+ )
+ elif j != k:
+ if j == i + 1:
+ m[i, j, k] = n[i] * cp[j] * sp[k]
+ elif k == i + 1:
+ m[i, j, k] = n[i] * sp[j] * cp[k]
else:
- # at least one of j, k is a cos^2 term, so we pick up
- # a negative and need to divide by cos^2
- if j != k:
- if j == i + 1:
- m[i, j, k] = n[i] * cp[j] * sp[k]
- elif k == i + 1:
- m[i, j, k] = n[i] * sp[j] * cp[k]
- else:
- # both are the cos^2 term, so we get a -2*cos
- m[i, j, k] = n[i] * (-2) * c2p[j] / np.cos(p[j]) ** 2
- if free:
- return m[:, self.free, self.free]
- return m
+ # both are the cos^2 term, so we get a -2*cos
+ m[i, j, k] = n[i] * (-2) * c2p[j] / np.cos(p[j]) ** 2
+ return m[:, self.free, self.free] if free else m
def get_total(self):
"""Return the amplitude of all norms."""
@@ -433,18 +419,12 @@ def reorder_components(self, indices):
# slopes, probably need to use the gradient to convert!
def eval_string(self):
- """Return a string that can be evaled to instantiate a nearly-
+ """Return a string that can be evaluated to instantiate a nearly-
identical object."""
t = self()
if len(t.shape) > 1:
t = t[:, 0] # handle e-dep
- return "%s(%s,free=%s,slope=%s,slope_free=%s)" % (
- self.__class__.__name__,
- str(list(t)),
- str(list(self.free)),
- str(list(self.slope)) if hasattr(self, "slope") else None,
- str(list(self.slope_free)) if hasattr(self, "slope_free") else None,
- )
+ return f'{self.__class__.__name__}({list(t)},free={list(self.free)},slope={str(list(self.slope)) if hasattr(self, "slope") else None},slope_free={str(list(self.slope_free)) if hasattr(self, "slope_free") else None})'
def dict_string(self):
"""Round down to avoid input errors w/ normalization."""
@@ -460,20 +440,19 @@ def pretty_list(l, places=6, round_down=True):
r = l
fmt = "%." + "%d" % places + "f"
s = ", ".join([fmt % x for x in r])
- return "[" + s + "]"
+ return f"[{s}]"
return [
- "name = %s" % self.__class__.__name__,
- "norms = %s" % (pretty_list(t)),
- "free = %s" % (str(list(self.free))),
+ f"name = {self.__class__.__name__}",
+ f"norms = {pretty_list(t)}",
+ f"free = {list(self.free)}",
"slope = %s"
% (
pretty_list(self.slope, round_down=False)
if hasattr(self, "slope")
else None
),
- "slope_free = %s"
- % (str(list(self.slope_free)) if hasattr(self, "slope_free") else None),
+ f'slope_free = {str(list(self.slope_free)) if hasattr(self, "slope_free") else None}',
]
diff --git a/src/pint/templates/lcprimitives.py b/src/pint/templates/lcprimitives.py
index b627506e8..cbcf665ba 100644
--- a/src/pint/templates/lcprimitives.py
+++ b/src/pint/templates/lcprimitives.py
@@ -123,7 +123,7 @@ def approx_derivative(func, phases, log10_ens=None, order=1, eps=1e-7):
This is "dTemplate/dPhi."
"""
- if not ((order == 1) or (order == 2)):
+ if order not in [1, 2]:
raise NotImplementedError("Only 1st and 2nd derivs supported.")
phhi = np.mod(phases + eps, 1)
@@ -240,11 +240,11 @@ def _einit(self):
pass
def parse_kwargs(self, kwargs):
- # acceptable keyword arguments, can be overriden by children
+ # acceptable keyword arguments, can be overridden by children
recognized_kwargs = ["p", "free"]
for key in kwargs.keys():
if key not in recognized_kwargs:
- raise ValueError("kwarg %s not recognized" % key)
+ raise ValueError(f"kwarg {key} not recognized")
self.__dict__.update(kwargs)
def __call__(self, phases):
@@ -253,9 +253,7 @@ def __call__(self, phases):
)
def num_parameters(self, free=True):
- if free:
- return np.sum(self.free)
- return len(self.free)
+ return np.sum(self.free) if free else len(self.free)
def get_free_mask(self):
"""Return a mask with True if parameters are free, else False."""
@@ -344,9 +342,7 @@ def set_parameters(self, p, free=True):
# return np.all(self.p>=self.bounds[:,0]) and np.all(self.p<=self.bounds[:,1])
def get_parameters(self, free=True):
- if free:
- return self.p[self.free]
- return self.p
+ return self.p[self.free] if free else self.p
def get_parameter_names(self, free=True):
return [p for (p, b) in zip(self.pnames, self.free) if b]
@@ -358,14 +354,10 @@ def set_errors(self, errs):
return n
def get_errors(self, free=True):
- if free:
- return self.errors[self.free]
- return self.errors
+ return self.errors[self.free] if free else self.errors
def get_bounds(self, free=True):
- if free:
- return np.asarray(self.bounds)[self.free]
- return self.bounds
+ return np.asarray(self.bounds)[self.free] if free else self.bounds
def check_bounds(self, p=None):
b = np.asarray(self.bounds)
@@ -392,9 +384,7 @@ def center_gauss_prior(self, enable=False):
self.enable_gauss_prior()
def get_location(self, error=False):
- if error:
- return np.asarray([self.p[-1], self.errors[-1]])
- return self.p[-1]
+ return np.asarray([self.p[-1], self.errors[-1]]) if error else self.p[-1]
def set_location(self, loc):
self.p[-1] = loc
@@ -461,17 +451,16 @@ def random(self, n, log10_ens=3):
accept = (
rfunc(N) < self(cand_phases, log10_ens=log10_ens[mask]) / M[mask]
)
+ elif isvector(log10_ens):
+ accept = rfunc(N) < self(cand_phases, log10_ens=log10_ens[mask]) / M
else:
- if isvector(log10_ens):
- accept = rfunc(N) < self(cand_phases, log10_ens=log10_ens[mask]) / M
- else:
- accept = rfunc(N) < self(cand_phases, log10_ens=log10_ens) / M
+ accept = rfunc(N) < self(cand_phases, log10_ens=log10_ens) / M
rvals[indices[mask][accept]] = cand_phases[accept]
mask[indices[mask][accept]] = False
return rvals
def __str__(self):
- m = max([len(n) for n in self.pnames])
+ m = max(len(n) for n in self.pnames)
l = []
errors = self.errors if hasattr(self, "errors") else [0] * len(self.pnames)
for i in range(len(self.pnames)):
@@ -511,7 +500,7 @@ def sanity_checks(self, eps=1e-6):
# gradient test
try:
t4 = self.check_gradient(quiet=True)
- except:
+ except Exception:
t4 = False
# boundary conditions
t5 = abs(self(0) - self(1 - eps)) < eps
@@ -528,15 +517,9 @@ def sanity_checks(self, eps=1e-6):
return np.all([t1, t2, t3, t4, t5])
def eval_string(self):
- """Return a string that can be evaled to instantiate a nearly-
+ """Return a string that can be evaluated to instantiate a nearly-
identical object."""
- return "%s(p=%s,free=%s,slope=%s,slope_free=%s)" % (
- self.__class__.__name__,
- str(list(self.p)),
- str(list(self.free)),
- str(list(self.slope)) if hasattr(self, "slope") else None,
- str(list(self.slope_free)) if hasattr(self, "slope_free") else None,
- )
+ return f'{self.__class__.__name__}(p={list(self.p)},free={list(self.free)},slope={str(list(self.slope)) if hasattr(self, "slope") else None},slope_free={str(list(self.slope_free)) if hasattr(self, "slope_free") else None})'
def dict_string(self):
"""Return a string to express the object as a dictionary that can
@@ -545,16 +528,14 @@ def dict_string(self):
def pretty_list(l, places=5):
fmt = "%." + "%d" % places + "f"
s = ", ".join([fmt % x for x in l])
- return "[" + s + "]"
+ return f"[{s}]"
t = [
- "name = %s" % self.__class__.__name__,
- "p = %s" % (pretty_list(self.p)),
- "free = %s" % (str(list(self.free))),
- "slope = %s"
- % (pretty_list(self.slope) if hasattr(self, "slope") else None),
- "slope_free = %s"
- % (str(list(self.slope_free)) if hasattr(self, "slope_free") else None),
+ f"name = {self.__class__.__name__}",
+ f"p = {pretty_list(self.p)}",
+ f"free = {list(self.free)}",
+ f'slope = {pretty_list(self.slope) if hasattr(self, "slope") else None}',
+ f'slope_free = {str(list(self.slope_free)) if hasattr(self, "slope_free") else None}',
]
# return 'dict(\n'+'\n '.join(t)+'\n
return t
@@ -637,9 +618,7 @@ def gradient(self, phases, log10_ens=3, free=False):
if gn is not None:
for i in range(len(gn)):
results[i, :] += gn[i]
- if free:
- return results[self.free]
- return results
+ return results[self.free] if free else results
def gradient_derivative(self, phases, log10_ens=3, free=False):
"""Return the gradient evaluated at a vector of phases.
@@ -658,9 +637,7 @@ def gradient_derivative(self, phases, log10_ens=3, free=False):
if gn is not None:
for i in range(len(gn)):
results[i, :] += gn[i]
- if free:
- return results[self.free]
- return results
+ return results[self.free] if free else results
def hessian(self, phases, log10_ens=3, free=False):
"""Return the hessian evaluated at a vector of phases.
@@ -682,9 +659,7 @@ def hessian(self, phases, log10_ens=3, free=False):
raise NotImplementedError
# for i in range(len(gn)):
# results[i,:] += gn[i]
- if free:
- return results[self.free, self.free]
- return results
+ return results[self.free, self.free] if free else results
def derivative(self, phases, log10_ens=3, order=1):
"""Return the phase gradient (dprim/dphi) at a vector of phases.
@@ -929,26 +904,28 @@ def base_grad(self, phases, log10_ens=3, index=0):
def base_grad_deriv(self, phases, log10_ens=3, index=0):
raise NotImplementedError
- e, width, x0 = self._make_p(log10_ens)
- z = (phases + index - x0) / width
- f = (1.0 / (width * ROOT2PI)) * np.exp(-0.5 * z**2)
- q = f / width**2
- z2 = z**2
- return np.asarray([q * z * (3 - z2), q * (1 - z2)])
+ # @abhisrkckl: commented out unreachable code.
+ # e, width, x0 = self._make_p(log10_ens)
+ # z = (phases + index - x0) / width
+ # f = (1.0 / (width * ROOT2PI)) * np.exp(-0.5 * z**2)
+ # q = f / width**2
+ # z2 = z**2
+ # return np.asarray([q * z * (3 - z2), q * (1 - z2)])
def base_hess(self, phases, log10_ens=3, index=0):
raise NotImplementedError
- e, width, x0 = self._make_p(log10_ens=log10_ens)
- z = (phases + index - x0) / width
- f = (1.0 / (width * ROOT2PI)) * np.exp(-0.5 * z**2)
- q = f / width**2
- z2 = z**2
- rvals = np.empty((2, 2, len(z)))
- rvals[0, 0] = q * (z2**2 - 5 * z2 + 2)
- rvals[0, 1] = q * (z2 - 3) * z
- rvals[1, 1] = q * (z2 - 1)
- rvals[1, 0] = rvals[0, 1]
- return rvals
+ # @abhisrkckl: commented out unreachable code.
+ # e, width, x0 = self._make_p(log10_ens=log10_ens)
+ # z = (phases + index - x0) / width
+ # f = (1.0 / (width * ROOT2PI)) * np.exp(-0.5 * z**2)
+ # q = f / width**2
+ # z2 = z**2
+ # rvals = np.empty((2, 2, len(z)))
+ # rvals[0, 0] = q * (z2**2 - 5 * z2 + 2)
+ # rvals[0, 1] = q * (z2 - 3) * z
+ # rvals[1, 1] = q * (z2 - 1)
+ # rvals[1, 0] = rvals[0, 1]
+ # return rvals
def base_derivative(self, phases, log10_ens=3, index=0, order=1):
e, width, shape, x0 = self._make_p(log10_ens)
@@ -1054,9 +1031,7 @@ def gradient(self, phases, log10_ens=3, free=False):
f2 = f**2
g1 = f * (c1 / s1) - f2
g2 = f2 * (TWOPI / s1) * s
- if free:
- return np.asarray([g1, g2])[self.free]
- return np.asarray([g1, g2])
+ return np.asarray([g1, g2])[self.free] if free else np.asarray([g1, g2])
def derivative(self, phases, log10_ens=3, index=0, order=1):
"""Return the phase gradient (dprim/dphi) at a vector of phases.
@@ -1235,9 +1210,7 @@ def gradient(self, phases, log10_ens=3, free=False):
rvals = np.empty([2, len(phases)])
rvals[0] = f * kappa**2 * (I1 / I0 - cz)
rvals[1] = f * (TWOPI * kappa) * sz
- if free:
- return rvals[self.free]
- return rvals
+ return rvals[self.free] if free else rvals
def derivative(self, phases, log10_ens=3, order=1):
# NB -- same as the (-ve) loc gradient
@@ -1288,7 +1261,8 @@ def init(self):
def hwhm(self, right=False):
raise NotImplementedError()
- return self.p[0] * (2 * np.log(2)) ** 0.5
+ # @abhisrkckl: commented out unreachable code.
+ # return self.p[0] * (2 * np.log(2)) ** 0.5
def base_func(self, phases, log10_ens=3, index=0):
e, s, g, x0 = self._make_p(log10_ens)
@@ -1298,10 +1272,11 @@ def base_func(self, phases, log10_ens=3, index=0):
def base_grad(self, phases, log10_ens=3, index=0):
raise NotImplementedError()
- e, width, x0 = self._make_p(log10_ens)
- z = (phases + index - x0) / width
- f = (1.0 / (width * ROOT2PI)) * np.exp(-0.5 * z**2)
- return np.asarray([f / width * (z**2 - 1.0), f / width * z])
+ # @abhisrkckl: commented out unreachable code.
+ # e, width, x0 = self._make_p(log10_ens)
+ # z = (phases + index - x0) / width
+ # f = (1.0 / (width * ROOT2PI)) * np.exp(-0.5 * z**2)
+ # return np.asarray([f / width * (z**2 - 1.0), f / width * z])
def base_int(self, x1, x2, log10_ens=3, index=0):
e, s, g, x0 = self._make_p(log10_ens)
@@ -1313,15 +1288,14 @@ def base_int(self, x1, x2, log10_ens=3, index=0):
f2 = 1 - (1.0 + u2 / g) ** (1 - g)
if z1 * z2 < 0: # span the peak
return 0.5 * (f1 + f2)
- if z1 < 0:
- return 0.5 * (f1 - f2)
- return 0.5 * (f2 - f1)
+ return 0.5 * (f1 - f2) if z1 < 0 else 0.5 * (f2 - f1)
def random(self, n):
raise NotImplementedError()
- if hasattr(n, "__len__"):
- n = len(n)
- return np.mod(norm.rvs(loc=self.p[-1], scale=self.p[0], size=n), 1)
+ # @abhisrkckl: commented out unreachable code.
+ # if hasattr(n, "__len__"):
+ # n = len(n)
+ # return np.mod(norm.rvs(loc=self.p[-1], scale=self.p[0], size=n), 1)
class LCTopHat(LCPrimitive):
diff --git a/src/pint/templates/lctemplate.py b/src/pint/templates/lctemplate.py
index db2a92de5..8e14f0834 100644
--- a/src/pint/templates/lctemplate.py
+++ b/src/pint/templates/lctemplate.py
@@ -5,6 +5,8 @@
author: M. Kerr
"""
+
+import contextlib
import logging
from collections import defaultdict
from copy import deepcopy
@@ -40,10 +42,7 @@ def __init__(self, primitives, norms=None, cache_kwargs=None):
self.shift_mode = np.any([p.shift_mode for p in self.primitives])
if norms is None:
norms = np.ones(len(primitives)) / len(primitives)
- if hasattr(norms, "_make_p"):
- self.norms = norms
- else:
- self.norms = NormAngles(norms)
+ self.norms = norms if hasattr(norms, "_make_p") else NormAngles(norms)
self._sanity_checks()
self._cache = defaultdict(None)
self._cache_dirty = defaultdict(lambda: True)
@@ -57,11 +56,10 @@ def __setstate__(self, state):
_cache_dirty = defaultdict(lambda: True)
if not hasattr(self, "_cache_dirty"):
self._cache = defaultdict(None)
- self._cache_dirty = _cache_dirty
else:
# make _cache_dirty a defaultdict from a normal dict
_cache_dirty.update(self._cache_dirty)
- self._cache_dirty = _cache_dirty
+ self._cache_dirty = _cache_dirty
if not hasattr(self, "ncache"):
self.ncache = 1000
if not hasattr(self, "ph_edges"):
@@ -91,9 +89,7 @@ def has_bridge(self):
def __getitem__(self, index):
if index < 0:
index += len(self.primitives) + 1
- if index == len(self.primitives):
- return self.norms
- return self.primitives[index]
+ return self.norms if index == len(self.primitives) else self.primitives[index]
def __setitem__(self, index, value):
if index < 0:
@@ -104,6 +100,7 @@ def __setitem__(self, index, value):
self.primitives[index] = value
def __len__(self):
+ # sourcery skip: remove-unreachable-code
raise DeprecationWarning("I'd like to see if this is used.")
return len(self.primitives)
@@ -357,20 +354,15 @@ def integrate(self, phi1, phi2, log10_ens=3, suppress_bg=False):
phi2 = np.asarray(phi2)
if isvector(log10_ens):
assert len(log10_ens) == len(phi1)
- try:
+ with contextlib.suppress(TypeError):
assert len(phi1) == len(phi2)
- except TypeError:
- pass
norms = self.norms(log10_ens=log10_ens)
t = norms.sum(axis=0)
dphi = phi2 - phi1
rvals = np.zeros(phi1.shape, dtype=float)
for n, prim in zip(norms, self.primitives):
rvals += n * prim.integrate(phi1, phi2, log10_ens=log10_ens)
- if suppress_bg:
- return rvals * (1.0 / t)
- else:
- return (1 - t) * dphi + rvals
+ return rvals * (1.0 / t) if suppress_bg else (1 - t) * dphi + rvals
def cdf(self, x, log10_ens=3):
return self.integrate(np.zeros_like(x), x, log10_ens, suppress_bg=False)
@@ -401,9 +393,7 @@ def __call__(self, phases, log10_ens=3, suppress_bg=False, use_cache=False):
rvals, norms, norm = self._get_scales(phases, log10_ens)
for n, prim in zip(norms, self.primitives):
rvals += n * prim(phases, log10_ens=log10_ens)
- if suppress_bg:
- return rvals / norm
- return (1.0 - norm) + rvals
+ return rvals / norm if suppress_bg else (1.0 - norm) + rvals
def derivative(self, phases, log10_ens=3, order=1, use_cache=False):
"""Return the derivative of the template with respect to pulse
@@ -422,9 +412,7 @@ def single_component(self, index, phases, log10_ens=3, add_bg=False):
"""Evaluate a single component of template."""
n = self.norms(log10_ens=log10_ens)
rvals = self.primitives[index](phases, log10_ens=log10_ens) * n[index]
- if add_bg:
- return rvals + n.sum(axis=0)
- return rvals
+ return rvals + n.sum(axis=0) if add_bg else rvals
def gradient(self, phases, log10_ens=3, free=True, template_too=False):
r = np.empty((self.num_parameters(free), len(phases)))
@@ -451,7 +439,7 @@ def gradient(self, phases, log10_ens=3, free=True, template_too=False):
np.einsum("ij,ikj->kj", prim_terms, m, out=r[c:])
if template_too:
rvals[:] = 1 - norm
- for i in range(0, len(prim_terms)):
+ for i in range(len(prim_terms)):
rvals += (prim_terms[i] + 1) * norms[i]
return r, rvals
return r
@@ -460,6 +448,7 @@ def gradient_derivative(self, phases, log10_ens=3, free=False):
"""Return d/dphi(gradient). This is the derivative with respect
to pulse phase of the gradient with respect to the parameters.
"""
+ # sourcery skip: remove-unreachable-code
raise NotImplementedError() # is this used anymore?
free_mask = self.get_free_mask()
nparam = len(free_mask)
@@ -508,7 +497,7 @@ def check_derivative(self, atol=1e-7, rtol=1e-5, order=1, eps=1e-7, quiet=False)
)
def hessian(self, phases, log10_ens=3, free=True):
- """Return the hessian of the primitive and normaliation angles.
+ """Return the hessian of the primitive and normalization angles.
The primitives components are not coupled due to the additive form
of the template. However, because each normalization depends on
@@ -566,12 +555,10 @@ def hessian(self, phases, log10_ens=3, free=True):
r[c + j, c + k, :] += hnorm[i, j, k] * prim_terms[i]
r[c + k, c + j, :] = r[c + j, c + k, :]
- if free:
- return r[free_mask][:, free_mask]
- return r
+ return r[free_mask][:, free_mask] if free else r
def delta(self, index=None):
- """Return radio lag -- reckoned by default as the posittion of the first peak following phase 0."""
+ """Return radio lag -- reckoned by default as the position of the first peak following phase 0."""
if (index is not None) and (index <= (len(self.primitives))):
return self[index].get_location(error=True)
return self.Delta(delta=True)
@@ -591,9 +578,7 @@ def Delta(self, delta=False):
prim1 = p
p1, e1 = prim0.get_location(error=True)
p2, e2 = prim1.get_location(error=True)
- if delta:
- return p1, e1
- return (p2 - p1, (e1**2 + e2**2) ** 0.5)
+ return (p1, e1) if delta else (p2 - p1, (e1**2 + e2**2) ** 0.5)
def _sorted_prims(self):
def cmp(p1, p2):
@@ -668,16 +653,13 @@ def random(self, n, weights=None, log10_ens=3, return_partition=False):
n = int(round(n))
if len(self.primitives) == 0:
- if return_partition:
- return np.random.rand(n), [n]
- return np.random.rand(n)
+ return (np.random.rand(n), [n]) if return_partition else np.random.rand(n)
# check weights
if weights is None:
weights = np.ones(n)
- else:
- if len(weights) != n:
- raise ValueError("Provided weight vector does not match requested n.")
+ elif len(weights) != n:
+ raise ValueError("Provided weight vector does not match requested n.")
# check energies
if isvector(log10_ens):
@@ -719,9 +701,7 @@ def random(self, n, weights=None, log10_ens=3, return_partition=False):
assert not np.any(np.isnan(rvals)) # TMP
- if return_partition:
- return rvals, comps
- return rvals
+ return (rvals, comps) if return_partition else rvals
def swap_primitive(self, index, ptype=LCLorentzian):
"""Swap the specified primitive for a new one with the parameters
@@ -739,13 +719,12 @@ def delete_primitive(self, index, inplace=False):
raise ValueError("Template only has a single primitive.")
if index < 0:
index += len(prims)
- newprims = [deepcopy(p) for ip, p in enumerate(prims) if not index == ip]
+ newprims = [deepcopy(p) for ip, p in enumerate(prims) if index != ip]
newnorms = self.norms.delete_component(index)
- if inplace:
- self.primitives = newprims
- self.norms = newnorms
- else:
+ if not inplace:
return LCTemplate(newprims, newnorms)
+ self.primitives = newprims
+ self.norms = newnorms
def add_primitive(self, prim, norm=0.1, inplace=False):
"""[Convenience] -- return a new LCTemplate with the specified
@@ -756,11 +735,10 @@ def add_primitive(self, prim, norm=0.1, inplace=False):
return LCTemplate([prim], [1])
nprims = [deepcopy(prims[i]) for i in range(len(prims))] + [prim]
nnorms = self.norms.add_component(norm)
- if inplace:
- self.norms = nnorms
- self.primitives = nprims
- else:
+ if not inplace:
return LCTemplate(nprims, nnorms)
+ self.norms = nnorms
+ self.primitives = nprims
def order_primitives(self, order=0):
"""Re-order components in place.
@@ -797,7 +775,7 @@ def add_energy_dependence(self, index, slope_free=True):
elif comp.name == "VonMises":
constructor = LCEVonMises
else:
- raise NotImplementedError("%s not supported." % comp.name)
+ raise NotImplementedError(f"{comp.name} not supported.")
newcomp = constructor(p=comp.p)
newcomp.free[:] = comp.free
newcomp.slope_free[:] = slope_free
@@ -809,10 +787,9 @@ def get_eval_string(self):
ps = "\n".join(
("p%d = %s" % (i, p.eval_string()) for i, p in enumerate(self.primitives))
)
- prims = "[%s]" % (",".join(("p%d" % i for i in range(len(self.primitives)))))
- ns = "norms = %s" % (self.norms.eval_string())
- s = "%s(%s,norms)" % (self.__class__.__name__, prims)
- return s
+ prims = f'[{",".join("p%d" % i for i in range(len(self.primitives)))}]'
+ ns = f"norms = {self.norms.eval_string()}"
+ return f"{self.__class__.__name__}({prims},norms)"
def closest_to_peak(self, phases):
return min((p.closest_to_peak(phases) for p in self.primitives))
@@ -901,9 +878,9 @@ def write_profile(self, fname, nbin, integral=False, suppress_bg=False):
phases = np.linspace(0, 1, 2 * nbin + 1)
values = self(phases, suppress_bg=suppress_bg)
hi = values[2::2]
- lo = values[0:-1:2]
+ lo = values[:-1:2]
mid = values[1::2]
- bin_phases = phases[0:-1:2]
+ bin_phases = phases[:-1:2]
bin_values = 1.0 / (6 * nbin) * (hi + 4 * mid + lo)
bin_values *= 1.0 / bin_values.mean()
diff --git a/src/pint/toa.py b/src/pint/toa.py
index 75d5c309d..014403918 100644
--- a/src/pint/toa.py
+++ b/src/pint/toa.py
@@ -3,7 +3,7 @@
In particular, single TOAs are represented by :class:`pint.toa.TOA` objects, and if you
want to manage a collection of these we recommend you use a :class:`pint.toa.TOAs` object
as this makes certain operations much more convenient. You probably want to load one with
-:func:`pint.toa.get_TOAs` (from a ``.tim`` file) or :func:`pint.toa.get_TOAs_array` (from a
+:func:`pint.toa.get_TOAs` (from a ``.tim`` file) or :func:`pint.toa.get_TOAs_array` (from a
:class:`numpy.ndarray` or :class:`astropy.time.Time` object).
Warning
@@ -14,6 +14,8 @@
has moved to :mod:`pint.simulation`.
"""
+
+import contextlib
import copy
import gzip
import pickle
@@ -148,7 +150,8 @@ def get_TOAs(
timfile : str or list of strings or file-like
Filename, list of filenames, or file-like object containing the TOA data.
ephem : str or None
- The name of the solar system ephemeris to use; defaults to ``pint.toa.EPHEM_default`` if ``None``
+ The name of the solar system ephemeris to use; defaults to the EPHEM parameter
+ in the timing model (`model`) if it is given, otherwise defaults to ``pint.toa.EPHEM_default``.
include_bipm : bool or None
Whether to apply the BIPM clock correction. Defaults to True.
bipm_version : str or None
@@ -165,7 +168,7 @@ def get_TOAs(
model : pint.models.timing_model.TimingModel or None
If a valid timing model is passed, model commands (such as BIPM version,
planet shapiro delay, and solar system ephemeris) that affect TOA loading
- are applied.
+ are applied. The solar system ephemeris is superseded by the `ephem` parameter.
usepickle : bool
Whether to try to use pickle-based caching of loaded clock-corrected TOAs objects.
tdb_method : str
@@ -216,7 +219,11 @@ def get_TOAs(
f'CLOCK = {model["CLOCK"].value} is not implemented. '
f"Using TT({bipm_default}) instead."
)
- if planets is None and model["PLANET_SHAPIRO"].value:
+ if (
+ planets is None
+ and "PLANET_SHAPIRO" in model
+ and model["PLANET_SHAPIRO"].value
+ ):
planets = True
log.debug("Using PLANET_SHAPIRO = True from the given model")
@@ -345,16 +352,12 @@ def load_pickle(toafilename, picklefilename=None):
lf = None
for fn in picklefilenames:
- try:
+ with contextlib.suppress(IOError, pickle.UnpicklingError, ValueError):
with gzip.open(fn, "rb") as f:
lf = pickle.load(f)
- except (IOError, pickle.UnpicklingError, ValueError):
- pass
- try:
+ with contextlib.suppress(IOError, pickle.UnpicklingError, ValueError):
with open(fn, "rb") as f:
lf = pickle.load(f)
- except (IOError, pickle.UnpicklingError, ValueError):
- pass
if lf is not None:
lf.was_pickled = True
return lf
@@ -413,7 +416,7 @@ def get_TOAs_list(
t.commands = [] if commands is None else commands
t.filename = filename
t.hashes = {} if hashes is None else hashes
- if not any(["clkcorr" in f for f in t.table["flags"]]):
+ if all("clkcorr" not in f for f in t.table["flags"]):
t.apply_clock_corrections(
include_gps=include_gps,
include_bipm=include_bipm,
@@ -531,7 +534,7 @@ def _parse_TOA_line(line, fmt="Unknown"):
d["freq"] = float(line[25:34])
ii = line[34:41]
ff = line[42:55]
- MJD = (int(ii), float("0." + ff))
+ MJD = int(ii), float(f"0.{ff}")
phaseoffset = float(line[55:62])
if phaseoffset != 0:
raise ValueError(
@@ -540,10 +543,8 @@ def _parse_TOA_line(line, fmt="Unknown"):
d["error"] = float(line[63:71])
d["obs"] = get_observatory(line[79].upper()).name
elif fmt == "ITOA":
- raise RuntimeError("TOA format '%s' not implemented yet" % fmt)
- elif fmt in ["Blank", "Comment"]:
- pass
- else:
+ raise RuntimeError(f"TOA format '{fmt}' not implemented yet")
+ elif fmt not in ["Blank", "Comment"]:
raise RuntimeError(
f"Unable to identify TOA format for line {line!r}, expecting {fmt}"
)
@@ -623,7 +624,7 @@ def format_toa_line(
freq = 0.0 * u.MHz
flagstring = ""
if dm != 0.0 * pint.dmu:
- flagstring += "-dm {0:.5f}".format(dm.to(pint.dmu).value)
+ flagstring += "-dm {:.5f}".format(dm.to(pint.dmu).value)
# Here I need to append any actual flags
for flag in flags.keys():
v = flags[flag]
@@ -663,7 +664,7 @@ def format_toa_line(
freq = 0.0 * u.MHz
if obs.tempo_code is None:
raise ValueError(
- "Observatory {} does not have 1-character tempo_code!".format(obs.name)
+ f"Observatory {obs.name} does not have 1-character tempo_code!"
)
if dm != 0.0 * pint.dmu:
out = obs.tempo_code + " %13s%9.3f%20s%9.2f %9.4f\n" % (
@@ -681,7 +682,7 @@ def format_toa_line(
toaerr.to(u.us).value,
)
else:
- raise ValueError("Unknown TOA format ({0})".format(format))
+ raise ValueError(f"Unknown TOA format ({format})")
return out
@@ -914,7 +915,7 @@ def _cluster_by_gaps(t, gap):
class FlagDict(MutableMapping):
def __init__(self, *args, **kwargs):
- self.store = dict()
+ self.store = {}
self.update(dict(*args, **kwargs))
@staticmethod
@@ -930,7 +931,7 @@ def check_allowed_key(k):
if not isinstance(k, str):
raise ValueError(f"flag {k} must be a string")
if k.startswith("-"):
- raise ValueError(f"flags should be stored without their leading -")
+ raise ValueError("flags should be stored without their leading -")
if not FlagDict._key_re.match(k):
raise ValueError(f"flag {k} is not a valid flag")
@@ -1076,26 +1077,19 @@ def __init__(
scale = site.timescale
# First build a time without a location
# Note that when scale is UTC, must use pulsar_mjd format!
- if scale.lower() == "utc":
- fmt = "pulsar_mjd"
- else:
- fmt = "mjd"
+ fmt = "pulsar_mjd" if scale.lower() == "utc" else "mjd"
t = time.Time(arg1, arg2, scale=scale, format=fmt, precision=9)
# Now assign the site location to the Time, for use in the TDB conversion
# Time objects are immutable so you must make a new one to add the location!
- # Use the intial time to look up the observatory location
+ # Use the initial time to look up the observatory location
# (needed for moving observatories)
# The location is an EarthLocation in the ITRF (ECEF, WGS84) frame
try:
loc = site.earth_location_itrf(time=t)
except Exception:
- # Just add informmation and re-raise
- log.error(
- "Error computing earth_location_itrf at time {0}, {1}".format(
- t, type(t)
- )
- )
+ # Just add information and re-raise
+ log.error(f"Error computing earth_location_itrf at time {t}, {type(t)}")
raise
# Then construct the full time, with observatory location set
self.mjd = time.Time(t, location=loc, precision=9)
@@ -1105,7 +1099,7 @@ def __init__(
self.error = error.to(u.microsecond)
except u.UnitConversionError:
raise u.UnitConversionError(
- "Uncertainty for TOA with incompatible unit {0}".format(error)
+ f"Uncertainty for TOA with incompatible unit {error}"
)
else:
self.error = error * u.microsecond
@@ -1113,10 +1107,10 @@ def __init__(
if hasattr(freq, "unit"):
try:
self.freq = freq.to(u.MHz)
- except u.UnitConversionError:
+ except u.UnitConversionError as e:
raise u.UnitConversionError(
- "Frequency for TOA with incompatible unit {0}".format(freq)
- )
+ f"Frequency for TOA with incompatible unit {freq}"
+ ) from e
else:
self.freq = freq * u.MHz
if self.freq == 0.0 * u.MHz:
@@ -1136,9 +1130,15 @@ def __str__(self):
+ f": {self.error.value:6.3f} {self.error.unit} error at '{self.obs}' at {self.freq.value:.4f} {self.freq.unit}"
)
if self.flags:
- s += " " + str(self.flags)
+ s += f" {str(self.flags)}"
return s
+ def __eq__(self, other):
+ result = True
+ for p in ["mjd", "error", "obs", "freq", "flags"]:
+ result = result and getattr(self, p) == getattr(other, p)
+ return result
+
def as_line(self, format="Tempo2", name=None, dm=0 * pint.dmu):
"""Format TOA as a line for a ``.tim`` file."""
if name is not None:
@@ -1260,6 +1260,8 @@ class TOAs:
The TOA objects this TOAs should contain.
toatable : astropy.table.Table, optional
An existing TOA table
+ tzr : bool
+ Whether the TOAs object corresponds to a TZR TOA
Exactly one of these three parameters must be provided.
@@ -1296,9 +1298,11 @@ class TOAs:
available to use names as compatible with TEMPO as possible.
wideband : bool
Whether the TOAs also have wideband DM information
+ tzr : bool
+ Whether the TOAs object corresponds to a TZR TOA
"""
- def __init__(self, toafile=None, toalist=None, toatable=None):
+ def __init__(self, toafile=None, toalist=None, toatable=None, tzr=False):
# First, just make an empty container
self.commands = []
self.filename = None
@@ -1310,6 +1314,7 @@ def __init__(self, toafile=None, toalist=None, toatable=None):
self.hashes = {}
self.was_pickled = False
self.alias_translation = None
+ self.tzr = tzr
if (toalist is not None) and (toafile is not None):
raise ValueError("Cannot initialize TOAs from both file and list.")
@@ -1332,9 +1337,8 @@ def __init__(self, toafile=None, toalist=None, toatable=None):
if toalist is None:
raise ValueError("No TOAs found!")
- else:
- if not isinstance(toalist, (list, tuple)):
- raise ValueError("Trying to initialize TOAs from a non-list class")
+ if not isinstance(toalist, (list, tuple)):
+ raise ValueError("Trying to initialize TOAs from a non-list class")
self.table = build_table(toalist, filename=self.filename)
else:
self.table = copy.deepcopy(toatable)
@@ -1485,45 +1489,37 @@ def __setitem__(self, index, value):
self.table[column] = value
else:
self.table[column][subset] = value
- else:
- # dealing with flags
- if np.isscalar(value):
- if subset is None:
- for f in self.table["flags"]:
- if value:
- f[column] = str(value)
- else:
- try:
- del f[column]
- except KeyError:
- pass
- elif isinstance(subset, int):
- f = self.table["flags"][subset]
+ elif np.isscalar(value):
+ if subset is None:
+ for f in self.table["flags"]:
if value:
f[column] = str(value)
else:
- try:
+ with contextlib.suppress(KeyError):
del f[column]
- except KeyError:
- pass
+ elif isinstance(subset, int):
+ f = self.table["flags"][subset]
+ if value:
+ f[column] = str(value)
else:
- for f in self.table["flags"][subset]:
- if value:
- f[column] = str(value)
- else:
- try:
- del f[column]
- except KeyError:
- pass
+ with contextlib.suppress(KeyError):
+ del f[column]
else:
- if subset is None:
- subset = range(len(self))
- if len(subset) != len(value):
- raise ValueError(
- "Length of flag values must be equal to length of TOA subset"
- )
- for i in subset:
- self[column, i] = str(value[i])
+ for f in self.table["flags"][subset]:
+ if value:
+ f[column] = str(value)
+ else:
+ with contextlib.suppress(KeyError):
+ del f[column]
+ else:
+ if subset is None:
+ subset = range(len(self))
+ if len(subset) != len(value):
+ raise ValueError(
+ "Length of flag values must be equal to length of TOA subset"
+ )
+ for i in subset:
+ self[column, i] = str(value[i])
def __repr__(self):
return f"{len(self)} TOAs starting at MJD {self.first_MJD}"
@@ -1595,14 +1591,47 @@ def wideband(self):
"""Whether or not the data have wideband TOA values"""
return self.is_wideband()
+ def to_TOA_list(self, clkcorr=False):
+ """Turn a :class:`pint.toa.TOAs` object into a list of :class:`pint.toa.TOA` objects
+
+ This effectively undoes :func:`pint.toa.get_TOAs_list`, optionally undoing clock corrections too
+
+ Parameters
+ ----------
+ clkcorr : bool, optional
+ Whether or not to undo any clock corrections
+
+ Returns
+ -------
+ list :
+ Of :class:`pint.toa.TOA` objects
+ """
+ tl = []
+ clkcorrs = self.get_flag_value("clkcorr", 0, float)[0] * u.s
+ for i in range(len(self)):
+ t = self.table["mjd"][i]
+ f = self.table["flags"][i]
+ if not clkcorr:
+ t -= clkcorrs[i]
+ if "clkcorr" in f:
+ del f["clkcorr"]
+ tl.append(
+ TOA(
+ MJD=t,
+ error=self.table["error"][i] * self.table["error"].unit,
+ obs=self.table["obs"][i],
+ freq=self.table["freq"][i] * self.table["freq"].unit,
+ flags=f,
+ )
+ )
+ return tl
+
def is_wideband(self):
"""Whether or not the data have wideband TOA values"""
# there may be a more elegant way to do this
dm_data, valid_data = self.get_flag_value("pp_dm", as_type=float)
- if valid_data == []:
- return False
- return True
+ return valid_data != []
def get_all_flags(self):
"""Return a list of all the flags used by any TOA."""
@@ -1755,27 +1784,25 @@ def get_clusters(self, gap_limit=2 * u.h, add_column=False, add_flag=None):
chronologically from zero.
"""
if (
- ("clusters" not in self.table.colnames)
- or ("cluster_gap" not in self.table.meta)
- or (gap_limit != self.table.meta["cluster_gap"])
+ "clusters" in self.table.colnames
+ and "cluster_gap" in self.table.meta
+ and gap_limit == self.table.meta["cluster_gap"]
):
- clusters = _cluster_by_gaps(
- self.get_mjds().to_value(u.d), gap_limit.to_value(u.d)
- )
- if add_column:
- self.table.add_column(clusters, name="clusters")
- self.table.meta["cluster_gap"] = gap_limit
- log.debug(f"Added 'clusters' column to TOA table with gap={gap_limit}")
- if add_flag is not None:
- for i in range(len(clusters)):
- self.table["flags"][i][add_flag] = str(clusters[i])
- self.table.meta["cluster_gap"] = gap_limit
- log.debug(f"Added '{add_flag}' flag to TOA table with gap={gap_limit}")
-
- return clusters
-
- else:
return self.table["clusters"]
+ clusters = _cluster_by_gaps(
+ self.get_mjds().to_value(u.d), gap_limit.to_value(u.d)
+ )
+ if add_column:
+ self.table.add_column(clusters, name="clusters")
+ self.table.meta["cluster_gap"] = gap_limit
+ log.debug(f"Added 'clusters' column to TOA table with gap={gap_limit}")
+ if add_flag is not None:
+ for i in range(len(clusters)):
+ self.table["flags"][i][add_flag] = str(clusters[i])
+ self.table.meta["cluster_gap"] = gap_limit
+ log.debug(f"Added '{add_flag}' flag to TOA table with gap={gap_limit}")
+
+ return clusters
def get_highest_density_range(self, ndays=7 * u.d):
"""Print the range of mjds (default 7 days) with the most toas"""
@@ -1822,10 +1849,10 @@ def check_hashes(self, timfile=None):
if len(timfiles) != len(filenames):
return False
- for t, f in zip(timfiles, filenames):
- if pint.utils.compute_hash(t) != self.hashes[f]:
- return False
- return True
+ return all(
+ pint.utils.compute_hash(t) == self.hashes[f]
+ for t, f in zip(timfiles, filenames)
+ )
def select(self, selectarray):
"""Apply a boolean selection or mask array to the TOA table.
@@ -1941,26 +1968,28 @@ def remove_pulse_numbers(self):
if "pulse_number" in self.table.colnames:
del self.table["pulse_number"]
else:
- log.warning(
- f"Requested deleting of pulse numbers, but they are not present"
- )
+ log.warning("Requested deleting of pulse numbers, but they are not present")
def adjust_TOAs(self, delta):
"""Apply a time delta to TOAs.
Adjusts the time (MJD) of the TOAs by applying delta, which should
- have the same shape as ``self.table['mjd']``. This function does not change
+ be a scalar or have the same shape as ``self.table['mjd']``. This function does not change
the pulse numbers column, if present, but does recompute ``mjd_float``,
the TDB times, and the observatory positions and velocities.
Parameters
----------
- delta : astropy.time.TimeDelta
+ delta : astropy.time.TimeDelta or astropy.units.Quantity
The time difference to add to the MJD of each TOA
"""
col = self.table["mjd"]
+ if not isinstance(delta, (time.TimeDelta, u.Quantity)):
+ raise ValueError("Type of argument must be Quantity or TimeDelta")
if not isinstance(delta, time.TimeDelta):
- raise ValueError("Type of argument must be TimeDelta")
+ delta = time.TimeDelta(delta)
+ if delta.isscalar:
+ delta = time.TimeDelta(np.repeat(delta.sec, len(col)) * u.s)
if delta.shape != col.shape:
raise ValueError("Shape of mjd column and delta must be compatible")
for ii in range(len(col)):
@@ -2071,7 +2100,7 @@ def write_TOA_file(
del toacopy.table["flags"][i]["pn"]
else:
log.warning(
- f"'pulse_number' column exists but it is not being written out"
+ "'pulse_number' column exists but it is not being written out"
)
if (
"delta_pulse_number" in toacopy.table.columns
@@ -2153,18 +2182,15 @@ def apply_clock_corrections(
"""
# First make sure that we haven't already applied clock corrections
flags = self.table["flags"]
- if any(["clkcorr" in f for f in flags]):
- if all(["clkcorr" in f for f in flags]):
- log.warning("Clock corrections already applied. Not re-applying.")
- return
- else:
+ if any("clkcorr" in f for f in flags):
+ if any("clkcorr" not in f for f in flags):
# FIXME: could apply clock corrections to just the ones that don't have any
raise ValueError("Some TOAs have 'clkcorr' flag and some do not!")
+ log.warning("Clock corrections already applied. Not re-applying.")
+ return
# An array of all the time corrections, one for each TOA
log.debug(
- "Applying clock corrections (include_gps = {0}, include_bipm = {1})".format(
- include_gps, include_bipm
- )
+ f"Applying clock corrections (include_gps = {include_gps}, include_bipm = {include_bipm})"
)
corrections = np.zeros(self.ntoas) * u.s
# values of "-to" flags
@@ -2223,21 +2249,17 @@ def compute_TDBs(self, method="default", ephem=None):
self.table.remove_column("tdbld")
if ephem is None:
- if self.ephem is not None:
- ephem = self.ephem
- else:
+ if self.ephem is None:
log.warning(
f"No ephemeris provided to TOAs object or compute_TDBs. Using {EPHEM_default}"
)
ephem = EPHEM_default
- else:
- # If user specifies an ephemeris, make sure it is the same as the one already
- # in the TOA object, to prevent mixing.
- if (self.ephem is not None) and (ephem != self.ephem):
- log.error(
- "Ephemeris provided to compute_TDBs {0} is different than TOAs object "
- "ephemeris {1}! Using TDB ephemeris.".format(ephem, self.ephem)
- )
+ else:
+ ephem = self.ephem
+ elif (self.ephem is not None) and (ephem != self.ephem):
+ log.error(
+ f"Ephemeris provided to compute_TDBs {ephem} is different than TOAs object ephemeris {self.ephem}! Using TDB ephemeris."
+ )
self.ephem = ephem
log.debug(f"Using EPHEM = {self.ephem} for TDB calculation.")
# Compute in observatory groups
@@ -2295,23 +2317,17 @@ def compute_posvels(self, ephem=None, planets=None):
specified, set ``self.planets`` to this value.
"""
if ephem is None:
- if self.ephem is not None:
- ephem = self.ephem
- else:
+ if self.ephem is None:
log.warning(
"No ephemeris provided to TOAs object or compute_posvels. Using DE421"
)
ephem = "DE421"
- else:
- # If user specifies an ephemeris, make sure it is the same as the one already in
- # the TOA object, to prevent mixing.
- if (self.ephem is not None) and (ephem != self.ephem):
- log.error(
- "Ephemeris provided to compute_posvels {0} is different than "
- "TOAs object ephemeris {1}! Using posvels ephemeris.".format(
- ephem, self.ephem
- )
- )
+ else:
+ ephem = self.ephem
+ elif (self.ephem is not None) and (ephem != self.ephem):
+ log.error(
+ f"Ephemeris provided to compute_posvels {ephem} is different than TOAs object ephemeris {self.ephem}! Using posvels ephemeris."
+ )
if planets is None:
planets = self.planets
# Record the choice of ephemeris and planets
@@ -2319,25 +2335,21 @@ def compute_posvels(self, ephem=None, planets=None):
self.planets = planets
if planets:
log.debug(
- "Computing PosVels of observatories, Earth and planets, using {}".format(
- ephem
- )
+ f"Computing PosVels of observatories, Earth and planets, using {ephem}"
)
else:
- log.debug(
- "Computing PosVels of observatories and Earth, using {}".format(ephem)
- )
+ log.debug(f"Computing PosVels of observatories and Earth, using {ephem}")
# Remove any existing columns
cols_to_remove = ["ssb_obs_pos", "ssb_obs_vel", "obs_sun_pos"]
for c in cols_to_remove:
if c in self.table.colnames:
- log.debug("Column {0} already exists. Removing...".format(c))
+ log.debug(f"Column {c} already exists. Removing...")
self.table.remove_column(c)
for p in all_planets:
- name = "obs_" + p + "_pos"
+ name = f"obs_{p}_pos"
if name in self.table.colnames:
- log.debug("Column {0} already exists. Removing...".format(name))
+ log.debug(f"Column {name} already exists. Removing...")
self.table.remove_column(name)
self.table.meta["ephem"] = ephem
@@ -2362,7 +2374,7 @@ def compute_posvels(self, ephem=None, planets=None):
if planets:
plan_poss = {}
for p in all_planets:
- name = "obs_" + p + "_pos"
+ name = f"obs_{p}_pos"
plan_poss[name] = table.Column(
name=name,
data=np.zeros((self.ntoas, 3), dtype=np.float64),
@@ -2380,14 +2392,14 @@ def compute_posvels(self, ephem=None, planets=None):
else:
ssb_obs = site.posvel(tdb, ephem)
- log.debug("SSB obs pos {0}".format(ssb_obs.pos[:, 0]))
+ log.debug(f"SSB obs pos {ssb_obs.pos[:, 0]}")
ssb_obs_pos[grp, :] = ssb_obs.pos.T.to(u.km)
ssb_obs_vel[grp, :] = ssb_obs.vel.T.to(u.km / u.s)
sun_obs = objPosVel_wrt_SSB("sun", tdb, ephem) - ssb_obs
obs_sun_pos[grp, :] = sun_obs.pos.T.to(u.km)
if planets:
for p in all_planets:
- name = "obs_" + p + "_pos"
+ name = f"obs_{p}_pos"
dest = p
pv = objPosVel_wrt_SSB(dest, tdb, ephem) - ssb_obs
plan_poss[name][grp, :] = pv.pos.T.to(u.km)
@@ -2444,7 +2456,7 @@ def add_vel_ecl(self, obliquity):
# get velocity vector from coordinate frame
ssb_obs_vel_ecl[grp, :] = coord.velocity.d_xyz.T.to(u.km / u.s)
col = ssb_obs_vel_ecl
- log.debug("Adding column " + col.name)
+ log.debug(f"Adding column {col.name}")
self.table.add_column(col)
def update_mjd_float(self):
@@ -2560,7 +2572,7 @@ def merge(self, t, *args, strict=False):
# some data have pulse_numbers but not all
# put in NaN
for i, tt in enumerate(TOAs_list):
- if not "pulse_number" in tt.table.colnames:
+ if "pulse_number" not in tt.table.colnames:
log.warning(
f"'pulse_number' not present in data set {i}: inserting NaNs"
)
@@ -2575,11 +2587,11 @@ def merge(self, t, *args, strict=False):
else:
# some data have positions/velocities but not all
# compute as needed
- for i, tt in enumerate(TOAs_list):
- if not (
- ("ssb_obs_pos" in tt.table.colnames)
- and ("ssb_obs_vel" in tt.table.colnames)
- and ("obs_sun_pos" in tt.table.colnames)
+ for tt in TOAs_list:
+ if (
+ "ssb_obs_pos" not in tt.table.colnames
+ or "ssb_obs_vel" not in tt.table.colnames
+ or "obs_sun_pos" not in tt.table.colnames
):
tt.compute_posvels()
if has_posvel_ecl.any() and not has_posvel_ecl.all():
@@ -2590,8 +2602,8 @@ def merge(self, t, *args, strict=False):
else:
# some data have ecliptic positions/velocities but not all
# compute as needed
- for i, tt in enumerate(TOAs_list):
- if not (("ssb_obs_vel_ecl" in tt.table.colnames)):
+ for tt in TOAs_list:
+ if "ssb_obs_vel_ecl" not in tt.table.colnames:
tt.add_vel_ecl(obliquity[0])
if has_tdb.any() and not has_tdb.all():
if strict:
@@ -2599,9 +2611,10 @@ def merge(self, t, *args, strict=False):
else:
# some data have TDBs but not all
# compute as needed
- for i, tt in enumerate(TOAs_list):
- if not (
- ("tdb" in tt.table.colnames) and ("tdbld" in tt.table.colnames)
+ for tt in TOAs_list:
+ if (
+ "tdb" not in tt.table.colnames
+ or "tdbld" not in tt.table.colnames
):
tt.compute_TDBs()
@@ -2612,8 +2625,7 @@ def merge(self, t, *args, strict=False):
nt.filename = []
for xx in filenames:
if type(xx) is list:
- for yy in xx:
- nt.filename.append(yy)
+ nt.filename.extend(iter(xx))
else:
nt.filename.append(xx)
# We do not ensure that the command list is flat
@@ -2638,14 +2650,14 @@ def merge(self, t, *args, strict=False):
# but it should be more helpful
message = []
for i, colnames in enumerate(all_colnames[1:]):
- extra_columns = [x for x in colnames if not x in all_colnames[0]]
- missing_columns = [x for x in all_colnames[0] if not x in colnames]
- if len(extra_columns) > 0:
+ extra_columns = [x for x in colnames if x not in all_colnames[0]]
+ missing_columns = [x for x in all_colnames[0] if x not in colnames]
+ if extra_columns:
message.append(
f"File {i+1} has extra column(s): {','.join(extra_columns)}"
)
- if len(missing_columns) > 0:
+ if missing_columns:
message.append(
f"File {i+1} has missing column(s): {','.join(missing_columns)}"
)
@@ -2685,7 +2697,8 @@ def merge_TOAs(TOAs_list, strict=False):
"""
# don't duplicate code: just use the existing method
t = copy.deepcopy(TOAs_list[0])
- t.merge(*TOAs_list[1:], strict=strict)
+ if len(TOAs_list) > 1:
+ t.merge(*TOAs_list[1:], strict=strict)
return t
@@ -2706,6 +2719,7 @@ def get_TOAs_array(
commands=None,
hashes=None,
limits="warn",
+ tzr=False,
**kwargs,
):
"""Load and prepare TOAs for PINT use from an array of times.
@@ -2770,6 +2784,8 @@ def get_TOAs_array(
has changed so that the file can be re-read if necessary.
limits : "warn" or "error"
What to do when encountering TOAs for which clock corrections are not available.
+ tzr : bool
+ Whether the TOAs object corresponds to a TZR TOA
Returns
-------
@@ -2914,9 +2930,9 @@ def get_TOAs_array(
)
flagdicts = [FlagDict.from_dict(f) for f in flags]
elif flags is not None:
- flagdicts = [FlagDict(flags)] * len(t)
+ flagdicts = [FlagDict(flags) for i in range(len(t))]
else:
- flagdicts = [FlagDict()] * len(t)
+ flagdicts = [FlagDict() for i in range(len(t))]
for k, v in kwargs.items():
if isinstance(v, (list, tuple, np.ndarray)):
@@ -2956,7 +2972,7 @@ def get_TOAs_array(
"delta_pulse_number",
),
)
- t = TOAs(toatable=out)
+ t = TOAs(toatable=out, tzr=tzr)
t.commands = [] if commands is None else commands
t.hashes = {} if hashes is None else hashes
if all("clkcorr" not in f for f in t.table["flags"]):
diff --git a/src/pint/toa_select.py b/src/pint/toa_select.py
index 0dd1d117f..52d1c1c1a 100644
--- a/src/pint/toa_select.py
+++ b/src/pint/toa_select.py
@@ -81,26 +81,19 @@ def check_table_column(self, new_column):
False for column has been changed.
"""
if self.use_hash:
- if new_column.name not in self.hash_dict.keys():
- self.hash_dict[new_column.name] = hash(new_column.tobytes())
- return False
- else:
- if self.hash_dict[new_column.name] == hash(new_column.tobytes()):
- return True
- else:
- # update hash value to new column
- self.hash_dict[new_column.name] = hash(new_column.tobytes())
- return False
+ if new_column.name in self.hash_dict.keys() and self.hash_dict[
+ new_column.name
+ ] == hash(new_column.tobytes()):
+ return True
+ # update hash value to new column
+ self.hash_dict[new_column.name] = hash(new_column.tobytes())
+ elif new_column.name not in self.columns_info.keys():
+ self.columns_info[new_column.name] = new_column
+ elif np.array_equal(self.columns_info[new_column.name], new_column):
+ return True
else:
- if new_column.name not in self.columns_info.keys():
- self.columns_info[new_column.name] = new_column
- return False
- else:
- if np.array_equal(self.columns_info[new_column.name], new_column):
- return True
- else:
- self.columns_info[new_column.name] = new_column
- return False
+ self.columns_info[new_column.name] = new_column
+ return False
def get_select_range(self, condition, column):
"""
@@ -125,9 +118,7 @@ def get_select_non_range(self, condition, column):
def get_select_index(self, condition, column):
# Check if condition get changed
cd_unchg, cd_chg = self.check_condition(condition)
- # check if column get changed.
- col_change = self.check_table_column(column)
- if col_change:
+ if col_change := self.check_table_column(column):
if self.is_range:
new_select = self.get_select_range(cd_chg, column)
else:
diff --git a/src/pint/utils.py b/src/pint/utils.py
index 59eb29e79..8e6ca28ab 100644
--- a/src/pint/utils.py
+++ b/src/pint/utils.py
@@ -39,6 +39,7 @@
import textwrap
from contextlib import contextmanager
from pathlib import Path
+import uncertainties
import astropy.constants as const
import astropy.coordinates as coords
@@ -87,6 +88,7 @@
"require_longdouble_precision",
"get_conjunction",
"divide_times",
+ "get_unit",
]
COLOR_NAMES = ["black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"]
@@ -346,7 +348,7 @@ def split_prefixed_name(name):
except AttributeError:
continue
else:
- raise PrefixError("Unrecognized prefix name pattern '%s'." % name)
+ raise PrefixError(f"Unrecognized prefix name pattern '{name}'.")
return prefix_part, index_part, int(index_part)
@@ -361,9 +363,9 @@ def taylor_horner(x, coeffs):
Parameters
----------
- x: astropy.units.Quantity
+ x: float or numpy.ndarray or astropy.units.Quantity
Input value; may be an array.
- coeffs: list of astropy.units.Quantity
+ coeffs: list of astropy.units.Quantity or uncertainties.ufloat
Coefficient array; must have length at least one. The coefficient in
position ``i`` is multiplied by ``x**i``. Each coefficient should
just be a number, not an array. The units should be compatible once
@@ -371,7 +373,7 @@ def taylor_horner(x, coeffs):
Returns
-------
- astropy.units.Quantity
+ float or numpy.ndarray or astropy.units.Quantity
Output value; same shape as input. Units as inferred from inputs.
"""
return taylor_horner_deriv(x, coeffs, deriv_order=0)
@@ -388,9 +390,9 @@ def taylor_horner_deriv(x, coeffs, deriv_order=1):
Parameters
----------
- x: astropy.units.Quantity
+ x: float or numpy.ndarray or astropy.units.Quantity
Input value; may be an array.
- coeffs: list of astropy.units.Quantity
+ coeffs: list of astropy.units.Quantity or uncertainties.ufloat
Coefficient array; must have length at least one. The coefficient in
position ``i`` is multiplied by ``x**i``. Each coefficient should
just be a number, not an array. The units should be compatible once
@@ -401,9 +403,10 @@ def taylor_horner_deriv(x, coeffs, deriv_order=1):
Returns
-------
- astropy.units.Quantity
+ float or numpy.ndarray or astropy.units.Quantity
Output value; same shape as input. Units as inferred from inputs.
"""
+ assert deriv_order >= 0
result = 0.0
if hasattr(coeffs[-1], "unit"):
if not hasattr(x, "unit"):
@@ -859,7 +862,7 @@ def dmxstats(model, toas, file=sys.stdout):
"""
mjds = toas.get_mjds()
freqs = toas.table["freq"]
- selected = np.zeros(len(toas), dtype=np.bool8)
+ selected = np.zeros(len(toas), dtype=np.bool_)
DMX_mapping = model.get_prefix_mapping("DMX_")
select_idx = dmxselections(model, toas)
for ii in DMX_mapping:
@@ -937,7 +940,7 @@ def dmxparse(fitter, save=False):
DMX_Errs = np.zeros(len(dmx_epochs))
DMX_R1 = np.zeros(len(dmx_epochs))
DMX_R2 = np.zeros(len(dmx_epochs))
- mask_idxs = np.zeros(len(dmx_epochs), dtype=np.bool8)
+ mask_idxs = np.zeros(len(dmx_epochs), dtype=np.bool_)
# Get DMX values (will be in units of 10^-3 pc cm^-3)
for ii, epoch in enumerate(dmx_epochs):
DMXs[ii] = getattr(fitter.model, "DMX_{:}".format(epoch)).value
@@ -1330,8 +1333,10 @@ def ELL1_check(
Checks whether the assumptions that allow ELL1 to be safely used are
satisfied. To work properly, we should have:
- :math:`asini/c e^2 \ll {\\rm timing precision} / \sqrt N_{\\rm TOA}`
- or :math:`A1 E^2 \ll TRES / \sqrt N_{\\rm TOA}`
+ :math:`asini/c e^4 \ll {\\rm timing precision} / \sqrt N_{\\rm TOA}`
+ or :math:`A1 E^4 \ll TRES / \sqrt N_{\\rm TOA}`
+
+ since the ELL1 model now includes terms up to O(E^3)
Parameters
----------
@@ -1352,12 +1357,12 @@ def ELL1_check(
If outstring is True then returns a string summary instead.
"""
- lhs = A1 / const.c * E**2.0
+ lhs = A1 / const.c * E**4.0
rhs = TRES / np.sqrt(NTOA)
if outstring:
s = "Checking applicability of ELL1 model -- \n"
- s += " Condition is asini/c * ecc**2 << timing precision / sqrt(# TOAs) to use ELL1\n"
- s += " asini/c * ecc**2 = {:.3g} \n".format(lhs.to(u.us))
+ s += " Condition is asini/c * ecc**4 << timing precision / sqrt(# TOAs) to use ELL1\n"
+ s += " asini/c * ecc**4 = {:.3g} \n".format(lhs.to(u.us))
s += " TRES / sqrt(# TOAs) = {:.3g} \n".format(rhs.to(u.us))
if lhs * 50.0 < rhs:
if outstring:
@@ -1577,7 +1582,7 @@ def remove_dummy_distance(c):
return c
-def info_string(prefix_string="# ", comment=None):
+def info_string(prefix_string="# ", comment=None, detailed=False):
"""Returns an informative string about the current state of PINT.
Adds:
@@ -1597,6 +1602,8 @@ def info_string(prefix_string="# ", comment=None):
comment or similar)
comment: str, optional
a free-form comment string to be included if present
+ detailed: bool, optional
+ Include detailed version info on dependencies.
Returns
-------
@@ -1687,13 +1694,61 @@ def info_string(prefix_string="# ", comment=None):
except (configparser.NoOptionError, configparser.NoSectionError, ImportError):
username = getpass.getuser()
- s = f"""
- Created: {datetime.datetime.now().isoformat()}
- PINT_version: {pint.__version__}
- User: {username}
- Host: {platform.node()}
- OS: {platform.platform()}
- """
+ info_dict = {
+ "Created": f"{datetime.datetime.now().isoformat()}",
+ "PINT_version": pint.__version__,
+ "User": username,
+ "Host": platform.node(),
+ "OS": platform.platform(),
+ "Python": sys.version,
+ }
+
+ if detailed:
+ from numpy import __version__ as numpy_version
+ from scipy import __version__ as scipy_version
+ from astropy import __version__ as astropy_version
+ from erfa import __version__ as erfa_version
+ from jplephem import __version__ as jpleph_version
+ from matplotlib import __version__ as matplotlib_version
+ from loguru import __version__ as loguru_version
+ from pint import __file__ as pint_file
+
+ info_dict.update(
+ {
+ "endian": sys.byteorder,
+ "numpy_version": numpy_version,
+ "numpy_longdouble_precision": np.dtype(np.longdouble).name,
+ "scipy_version": scipy_version,
+ "astropy_version": astropy_version,
+ "pyerfa_version": erfa_version,
+ "jplephem_version": jpleph_version,
+ "matplotlib_version": matplotlib_version,
+ "loguru_version": loguru_version,
+ "Python_prefix": sys.prefix,
+ "PINT_file": pint_file,
+ }
+ )
+
+ if "CONDA_PREFIX" in os.environ:
+ conda_prefix = os.environ["CONDA_PREFIX"]
+ info_dict.update(
+ {
+ "Environment": "conda",
+ "conda_prefix": conda_prefix,
+ }
+ )
+ elif "VIRTUAL_ENV" in os.environ:
+ venv_prefix = os.environ["VIRTUAL_ENV"]
+ info_dict.update(
+ {
+ "Environment": "virtualenv",
+ "virtualenv_prefix": venv_prefix,
+ }
+ )
+
+ s = ""
+ for key, val in info_dict.items():
+ s += f"{key}: {val}\n"
s = textwrap.dedent(s)
# remove blank lines
@@ -2018,3 +2073,59 @@ def convert_dispersion_measure(dm, dmconst=None):
me = constants.m_e.si
dmconst = e**2 / (8 * np.pi**2 * c * eps0 * me)
return (dm * pint.DMconst / dmconst).to(pint.dmu)
+
+
+def parse_time(input, scale="tdb", precision=9):
+ """Parse an :class:`astropy.time.Time` object from a range of input types
+
+ Parameters
+ ----------
+ input : astropy.time.Time, astropy.units.Quantity, numpy.ndarray, float, int, str
+ Value to parse
+ scale : str, optional
+ Scale of time for conversion
+ precision : int, optional
+ Precision for time
+
+ Returns
+ -------
+ astropy.time.Time
+ """
+ if isinstance(input, Time):
+ return input if input.scale == scale else getattr(input, scale)
+ elif isinstance(input, u.Quantity):
+ return Time(
+ input.to(u.d), format="pulsar_mjd", scale=scale, precision=precision
+ )
+ elif isinstance(input, (np.ndarray, float, int)):
+ return Time(input, format="pulsar_mjd", scale=scale, precision=precision)
+ elif isinstance(input, str):
+ return Time(input, format="pulsar_mjd_string", scale=scale, precision=precision)
+ else:
+ raise TypeError(f"Do not know how to parse times from {type(input)}")
+
+
+def get_unit(parname):
+ """Return the unit associated with a parameter
+
+ Handles normal parameters, along with aliases and indexed parameters
+ (e.g., `pint.models.parameter.prefixParameter`
+ and `pint.models.parameter.maskParameter`) with an index beyond those currently
+ initialized.
+
+ This can be used without an existing :class:`~pint.models.TimingModel`.
+
+ Parameters
+ ----------
+ name : str
+ Name of PINT parameter or alias
+
+ Returns
+ -------
+ astropy.u.Unit
+ """
+ # import in the function to avoid circular dependencies
+ from pint.models.timing_model import AllComponents
+
+ ac = AllComponents()
+ return ac.param_to_unit(parname)
diff --git a/tempo2Test/T2spiceTest.py b/tempo2Test/T2spiceTest.py
index f548f62d4..9403d539e 100644
--- a/tempo2Test/T2spiceTest.py
+++ b/tempo2Test/T2spiceTest.py
@@ -1,5 +1,3 @@
-from __future__ import print_function, division
-
# import matplotlib
# matplotlib.use('TKAgg')
import matplotlib.pyplot as plt
@@ -20,64 +18,52 @@ def mjd2et(mjd, tt2tdb):
mjdJ2000 = mp.mpf("51544.5")
secDay = mp.mpf("86400.0")
mjdTT = tc.mjd2tdt(mp.mpf(mjd))
- # Convert mjdutc to mjdtdt using HP time convert lib
- # print "python ",mjdTT
- et = (mp.mpf(mjdTT) - mjdJ2000) * mp.mpf(86400.0) + mp.mpf(tt2tdb)
- return et
+ return (mp.mpf(mjdTT) - mjdJ2000) * mp.mpf(86400.0) + mp.mpf(tt2tdb)
#### Read tempo2 tim file
fname1 = "J0000+0000.tim"
-fp1 = open(fname1, "r")
-
-toa = []
-# Read TOA column to toa array
-for l in fp1.readlines():
- l = l.strip()
- l = l.strip("\n")
- l = l.split()
- if len(l) > 3:
- toa.append(l[2])
+with open(fname1, "r") as fp1:
+ toa = []
+ # Read TOA column to toa array
+ for l in fp1:
+ l = l.strip()
+ l = l.strip("\n")
+ l = l.split()
+ if len(l) > 3:
+ toa.append(l[2])
#### Read tempo2 general2 output file
fname = "T2output.dat"
-fp = open(fname, "r")
-
-tt2tdb = [] # Tempo2 tt2tdb difference in (sec)
-earth1 = [] # Tempo2 earth position in (light time, sec)
-earth2 = []
-earth3 = []
-# Read tt2tdb earthposition output
-for l in fp.readlines():
- l = l.strip()
- l = l.strip("\n")
- l = l.split()
- # Avoid the column that is not data
- try:
- m = float(l[0])
- except:
- pass
- else:
- tt2tdb.append(l[-1])
- earth1.append(l[0])
- earth2.append(l[1])
- earth3.append(l[2])
-#### Testing toa mjd to tt
-tt = []
-for i in range(len(toa)):
- tt.append(tc.mjd2tdt(mp.mpf(toa[i])))
-
-# Testing Convert toa mjd to toa et
-et = []
+with open(fname, "r") as fp:
+ tt2tdb = [] # Tempo2 tt2tdb difference in (sec)
+ earth1 = [] # Tempo2 earth position in (light time, sec)
+ earth2 = []
+ earth3 = []
+ # Read tt2tdb earth position output
+ for l in fp:
+ l = l.strip()
+ l = l.strip("\n")
+ l = l.split()
+ # Avoid the column that is not data
+ try:
+ m = float(l[0])
+ except Exception:
+ pass
+ else:
+ tt2tdb.append(l[-1])
+ earth1.append(l[0])
+ earth2.append(l[1])
+ earth3.append(l[2])
-for i in range(len(toa)):
- et.append(mjd2et(toa[i], tt2tdb[i]))
+tt = [tc.mjd2tdt(mp.mpf(item)) for item in toa]
+et = [mjd2et(toa[i], tt2tdb[i]) for i in range(len(toa))]
###### calculate earth position
stateInterp = [] # interpolated earth position in (km)
ltInterp = [] # interpolated earth to ssb light time in (sec)
-statespk = [] # Directlt calculated earth position in (km)
+statespk = [] # Directly calculated earth position in (km)
ltspk = [] # Directly calculated earth to ssb lt time in (sec)
-# Calculating postion
+# Calculating position
for time in et:
(
state0,
@@ -108,7 +94,7 @@ def mjd2et(mjd, tt2tdb):
stateInterpy.append(mp.mpf(stateInterp[i][1]))
statespky.append(mp.mpf(statespk[i][1]))
diff1.append(stateInterpy[i] - statespky[i])
- # Difference between interploated position and tempo2 postion out put in KM
+ # Difference between interpolated position and tempo2 position out put in KM
diff2.append(stateInterpy[i] - mp.mpf(earth2[i]) * mp.mpf(spice.clight()))
plt.figure(1)
diff --git a/tempo2Test/spiceTest.py b/tempo2Test/spiceTest.py
index 18139869d..5213ae304 100644
--- a/tempo2Test/spiceTest.py
+++ b/tempo2Test/spiceTest.py
@@ -1,11 +1,10 @@
-from __future__ import print_function, division
import spice
import numpy as np
def test_lmt(et, step):
"""
- Testing how accurate that spice can distinguish two near by times(et) with littl time step
+ Testing how accurate that spice can distinguish two near by times(et) with little time step
et is the initial time
step is the small time step
"""
@@ -44,12 +43,10 @@ def spice_Intplt(et, stepBitNum):
print(exctNum0, exctNum1)
state0, lt0 = spice.spkezr("EARTH", exctNum0, "J2000", "NONE", "SSB")
state1, lt1 = spice.spkezr("EARTH", exctNum1, "J2000", "NONE", "SSB")
- state = []
- lt = []
- for i in range(6):
- state.append(np.interp(et, [exctNum0, exctNum1], [state0[i], state1[i]]))
- lt.append(np.interp(et, [exctNum0, exctNum1], [lt0, lt1]))
-
+ state = [
+ np.interp(et, [exctNum0, exctNum1], [state0[i], state1[i]]) for i in range(6)
+ ]
+ lt = [np.interp(et, [exctNum0, exctNum1], [lt0, lt1])]
stateOr, ltOr = spice.spkezr("EARTH", et, "J2000", "NONE", "SSB")
return state, stateOr, np.array(state) - np.array(stateOr)
@@ -57,7 +54,7 @@ def spice_Intplt(et, stepBitNum):
def spkInterp(et, stepBitNum):
"""
- This function interpolates earth state in one second with seveal exact points.
+ This function interpolates earth state in one second with several exact points.
To increase accuracy, each know point will be the exact number that can be represented
by double precision.
et is the target time
diff --git a/tempo2Test/tt2tdbT2.py b/tempo2Test/tt2tdbT2.py
index 44ff4ba47..236ab0842 100644
--- a/tempo2Test/tt2tdbT2.py
+++ b/tempo2Test/tt2tdbT2.py
@@ -89,8 +89,7 @@ def mjd2tdt(mjd):
delta_TT = mp.mpf(dt) + 32.184
delta_TT_DAY = mp.mpf(delta_TT) / mp.mpf(86400.0)
delta_TT_DAY = mp.mpf(delta_TT_DAY)
- mjd_tt = mp.mpf(mjd) + delta_TT_DAY
- return mjd_tt
+ return mp.mpf(mjd) + delta_TT_DAY
#### From mjdutc to et
@@ -101,8 +100,7 @@ def mjd2et(mjd, tt2tdb):
mjdJ2000 = mp.mpf("51544.5")
secDay = mp.mpf("86400.0")
mjdTT = mjd2tdt(mjd)
- et = (mjdTT - mjdJ2000) * secDay + mp.mpf(tt2tdb)
- return et
+ return (mjdTT - mjdJ2000) * secDay + mp.mpf(tt2tdb)
#### Read tempo2 tim file
@@ -111,7 +109,7 @@ def mjd2et(mjd, tt2tdb):
toa = []
# Read TOA column to toa array
-for l in fp1.readlines():
+for l in fp1:
l = l.strip()
l = l.strip("\n")
l = l.split()
@@ -126,15 +124,15 @@ def mjd2et(mjd, tt2tdb):
earth1 = [] # Tempo2 earth position in (light time, sec)
earth2 = []
earth3 = []
-# Read tt2tdb earthposition output
-for l in fp.readlines():
+# Read tt2tdb earth position output
+for l in fp:
l = l.strip()
l = l.strip("\n")
l = l.split()
# Avoid the column that is not data
try:
m = float(l[0])
- except:
+ except Exception:
pass
else:
tt2tdb.append(l[-1])
@@ -142,17 +140,13 @@ def mjd2et(mjd, tt2tdb):
earth2.append(l[1])
earth3.append(l[2])
-et = []
-# Convert toa mjd to toa et
-for i in range(len(toa)):
- et.append(mjd2et(toa[i], tt2tdb[i]))
-
+et = [mjd2et(toa[i], tt2tdb[i]) for i in range(len(toa))]
###### calculate earth position
stateInterp = [] # interpolated earth position in (km)
ltInterp = [] # interpolated earth to ssb light time in (sec)
-statespk = [] # Directlt calculated earth position in (km)
+statespk = [] # Directly calculated earth position in (km)
ltspk = [] # Directly calculated earth to ssb lt time in (sec)
-# Calculating postion
+# Calculating position
for time in et:
(
state0,
diff --git a/tests/datafile/observatory/aliases b/tests/datafile/observatory/aliases
new file mode 100644
index 000000000..54f2e958f
--- /dev/null
+++ b/tests/datafile/observatory/aliases
@@ -0,0 +1,38 @@
+# This file is a list of aliases for observatories.
+# Each line lists the aliases for a single observatory, the short code for
+# which is the first word on the line. One or more aliases may be listed for
+# each observatory.
+#
+# These aliases differ from the tempo aliases because there are too many
+# observatories and different sites were added differently at times.
+#
+# If you need to use tempo site codes, try setting TEMPO2_ALIAS="tempo"
+gbt 1 gb
+atca 2
+ao 3 arecebo arecibo
+hobart 4
+nanshan 5
+tid43 6
+pks 7
+jb 8 y # z # this is used by srt
+vla c
+ncy f
+eff g
+jbdfb q
+wsrt i
+# jb42 i # This is used by wsrt now
+lofar t
+ncyobs w
+lwa1 x
+srt z
+jbmk2 h
+jbmk3 j
+meerkat m
+gmrt r
+# eMerlin sites
+tabley k
+darnhall l
+knockin k
+defford n
+# Other single char codes in observatories.dat
+lapalma p
diff --git a/tests/datafile/observatory/observatories.dat b/tests/datafile/observatory/observatories.dat
new file mode 100644
index 000000000..3a27f0ad7
--- /dev/null
+++ b/tests/datafile/observatory/observatories.dat
@@ -0,0 +1,157 @@
+# New format observatory information file.
+#
+# Values are x,y,z geocentric coordinates
+# The last column contains the old-style telescope ID code
+#
+ 882589.289 -4924872.368 3943729.418 GBT gbt
+-4752329.7000 2790505.9340 -3200483.7470 NARRABRI atca
+ 2390487.080 -5564731.357 1994720.633 ARECIBO ao
+ 228310.702 4631922.905 4367064.059 NANSHAN NS
+ 228310.702 4631922.905 4367064.059 UAO NS
+ -4460892.6 2682358.9 -3674756.0 DSS_43 tid43
+ -4554231.5 2816759.1 -3454036.3 PARKES pks
+ 4865182.7660 791922.6890 4035137.1740 SRT srt
+ -1601192. -5041981.4 3554871.4 VLA vla
+ 4324165.81 165927.11 4670132.83 NANCAY ncy
+ 4324165.81 165927.11 4670132.83 NUPPI ncyobs
+ 4324165.81 165927.11 4670132.83 OP obspm
+ 4033949.5 486989.4 4900430.8 EFFELSBERG eff
+ 4033949.5 486989.4 4900430.8 EFFELSBERG_ASTERIX effix
+ 4033949.5 486989.4 4900430.8 LEAP leap
+ 3822252.643 -153995.683 5086051.443 JODRELLM4 jbm4
+ 881856.58 -4925311.86 3943459.70 GB300 gb300
+ 882872.57 -4924552.73 3944154.92 GB140 gb140
+ 882315.33 -4925191.41 3943414.05 GB853 gb853
+ 883775.18 -4924398.84 3944052.95 NRAO20 gb20
+ 5327021.651 -1719555.576 3051967.932 LA_PALMA lap
+ -3950077.96 2522377.31 -4311667.52 Hobart hob
+ 5085442.780 2668263.483 -2768697.034 Hartebeesthoek hart
+ 3828445.659 445223.600000 5064921.5677 WSRT wsrt
+ -5115425.60 477880.31 -3767042.81 Warkworth_30m wark30m
+ -5115324.399 477843.305 -3767192.886 Warkworth_12m wark12m
+ 3826577.462 461022.624 5064892.526 LOFAR lofar
+ 4034038.635 487026.223 4900280.057 DE601LBA EFlfrlba
+ 4034038.635 487026.223 4900280.057 DE601LBH EFlfrlbh
+ 4034101.901 487012.401 4900230.210 DE601HBA EFlfrhba
+ 4034101.901 487012.401 4900230.210 DE601 EFlfr
+ 4152561.068 828868.725 4754356.878 DE602LBA UWlfrlba
+ 4152561.068 828868.725 4754356.878 DE602LBH UWlfrlbh
+ 4152568.416 828788.802 4754361.926 DE602HBA UWlfrhba
+ 4152568.416 828788.802 4754361.926 DE602 UWlfr
+ 3940285.328 816802.001 4932392.757 DE603LBA TBlfrlba
+ 3940285.328 816802.001 4932392.757 DE603LBH TBlfrlbh
+ 3940296.126 816722.532 4932394.152 DE603HBA TBlfrhba
+ 3940296.126 816722.532 4932394.152 DE603 TBlfr
+ 3796327.609 877591.315 5032757.252 DE604LBA POlfrlba
+ 3796327.609 877591.315 5032757.252 DE604LBH POlfrlbh
+ 3796380.254 877613.809 5032712.272 DE604HBA POlfrhba
+ 3796380.254 877613.809 5032712.272 DE604 POlfr
+ 4005681.742 450968.282 4926457.670 DE605LBA JUlfrlba
+ 4005681.742 450968.282 4926457.670 DE605LBH JUlfrlbh
+ 4005681.407 450968.304 4926457.940 DE605HBA JUlfrhba
+ 4005681.407 450968.304 4926457.940 DE605 JUlfr
+ 4323980.155 165608.408 4670302.803 FR606LBA FRlfrlba
+ 4323980.155 165608.408 4670302.803 FR606LBH FRlfrlbh
+ 4324017.054 165545.160 4670271.072 FR606HBA FRlfrhba
+ 4324017.054 165545.160 4670271.072 FR606 FRlfr
+ 3370287.366 712053.586 5349991.228 SE607LBA ONlfrlba
+ 3370287.366 712053.586 5349991.228 SE607LBH ONlfrlbh
+ 3370272.092 712125.596 5349990.934 SE607HBA ONlfrhba
+ 3370272.092 712125.596 5349990.934 SE607 ONlfr
+ 4008438.796 -100310.064 4943735.554 UK608LBA UKlfrlba
+ 4008438.796 -100310.064 4943735.554 UK608LBH UKlfrlbh
+ 4008462.280 -100376.948 4943716.600 UK608HBA UKlfrhba
+ 4008462.280 -100376.948 4943716.600 UK608 UKlfr
+ 3727207.778 655184.900 5117000.625 DE609LBA NDlfrlba
+ 3727207.778 655184.900 5117000.625 DE609LBH NDlfrlbh
+ 3727218.128 655108.821 5117002.847 DE609HBA NDlfrhba
+ 3727218.128 655108.821 5117002.847 DE609 NDlfr
+ 2136833.225 810088.740 5935285.279 FI609LBA Filfrlba
+ 2136833.225 810088.740 5935285.279 FI609LBH Filfrlbh
+ 2136819.1940 810039.5757 5935299.0536 FI609HBA Filfrhba
+ 2136819.1940 810039.5757 5935299.0536 FI609 Filfr
+ 3307865.236 2487350.541 4836939.784 UTR-2 UTR2
+ 1656342.30 5797947.77 2073243.16 GMRT gmrt
+ -2353621.22 -4641341.52 3677052.352 GOLDSTONE gs
+ -4483311.64 2648815.92 -3671909.31 MOST mo
+ -2826711.951 4679231.627 3274665.675 SHAO shao
+ 5088964.00 301689.80 3825017.0 PICO_VELETA pv
+ -2524263.18 -4123529.78 4147966.36 ATA ata
+
+
+# IAR
+ 2765357.08 -4449628.98 -3625726.47 IAR1 iar1
+ 2765322.49 -4449569.52 -3625825.14 IAR2 iar2
+#
+
+# FAST
+-1666460.00 5499910.00 2759950.00 FAST fast
+#
+
+# Murchison Widefield Array (MWA)
+-2559454.08 5095372.14 -2849057.18 MWA mwa
+
+# MeerKAT
+#
+5109943.1050 2003650.7359 -3239908.3195 KAT-7 k7
+5109360.133 2006852.586 -3238948.127 MEERKAT meerkat
+
+# eMerlin telescopes
+####### From Jodrell obsys.dat file
+#
+ 383395.727 -173759.585 5077751.313 MKIII jbmk3
+ 3817176.557 -162921.170 5089462.046 TABLEY tabley
+ 3828714.504 -169458.987 5080647.749 DARNHALL darnhall
+ 3859711.492 -201995.082 5056134.285 KNOCKIN knockin
+ 3923069.135 -146804.404 5009320.570 DEFFORD defford
+ 3919982.752 2651.9820 5013849.826 CAMBRIDGE cam
+ 0.0 1.0 0.0 COE coe
+
+# MJK Approximate location for Princeton observatory
+ 1288748.38 -4694221.77 4107418.80 PRINCETON princeton
+ 3788815.62 1131748.336 5035101.190 HAMBURG hamburg
+#
+###### Telescope ID changed from the Jodrell obsys.dat file
+# Jodrell Bank Telescopes and Backends
+# JBO generic... use only for clock corrections
+ 3822626.04 -154105.65 5086486.04 JODRELL jb
+# Lovell Digital Filterbank
+ 3822626.04 -154105.65 5086486.04 JBODFB jbdfb
+# Lovell ROACH coherent dedisperser
+ 3822626.04 -154105.65 5086486.04 JBOROACH jbroach
+# Lovell Analogue Filterbank
+ 3822626.04 -154105.65 5086486.04 JBOAFB jbafb
+# 42ft telescope
+ 3822294.825 -153862.275 5085987.071 JB_42ft jb42
+
+# MkII Position from VLBI (MJK 2018)
+# X= 3822846.7600 Y= -153802.2800 Z= 5086285.9000
+# MkII telescope Analogue Filterbank
+ 3822846.7600 -153802.2800 5086285.9000 JB_MKII jbmk2
+# MkII ROACH coherent dedisperser
+ 3822846.7600 -153802.2800 5086285.9000 JB_MKII_RCH jbmk2roach
+# MkII Digital Filterbank
+ 3822846.7600 -153802.2800 5086285.9000 JB_MKII_DFB jbmk2dfb
+
+#
+# New telescopes
+ 1719555.576 5327021.651 3051967.932 LA_PALMA lapalma
+ -1602196.60 -5042313.47 3553971.51 LWA1 lwa1
+-1531155.54418 -5045324.30517 3579583.89450 LWA_SV lwasv
+
+ 6346273.5310 -33779.7127 634844.9454 GRAO grao
+
+ -2059166.313 -3621302.972 4814304.113 CHIME chime
+
+
+# groud-based gravitational wave observatories
+ 4546374.0990 842989.6976 4378576.9624 VIRGO virgo
+ -2161414.9264 -3834695.1789 4600350.2266 HANFORD lho
+ -74276.0447 -5496283.7197 3224257.0174 LIVINGSTON llo
+ 3856309.9493 666598.9563 5019641.4172 GEO600 geo600
+ -3777336.0240 3484898.411 3765313.6970 KAGRA kagra
+
+# Fake telescope for IPTA data challenge
+ 6378138.00 0.0 0.0 AXIS axi
+
+ 0 0 0 STL_BAT STL_BAT
diff --git a/tests/datafile/observatory/obsys.dat b/tests/datafile/observatory/obsys.dat
new file mode 100644
index 000000000..0797c42b6
--- /dev/null
+++ b/tests/datafile/observatory/obsys.dat
@@ -0,0 +1,30 @@
+ 882589.65 -4924872.32 3943729.348 1 GBT XYZ 1 GB
+ 422333.2 722040.4 306. QUABBIN 2 QU
+ 2390490.0 -5564764.0 1994727.0 1 ARECIBO XYZ (JPL) 3 AO
+ -424818.0 -1472621.0 50.0 Hobart, Tasmania 4 HO
+ 402047.7 743853.85 43. PRINCETON 5 PR
+ -1601192. -5041981.4 3554871.4 1 VLA XYZ 6 VL
+ -4554231.5 2816759.1 -3454036.3 1 PARKES XYZ (JER) 7 PK
+ 3822626.04 -154105.65 5086486.04 1 JODRELL BANK MkIA 8 JB
+ 382546.30 795056.36 893.7 GB 300FT 9 G3
+ 382615.409 795009.613 880.87 GB 140FT a G1
+ 382545.90 795036.87 835.8 GB 85-3 b G8
+ 340443.497 1073703.819 2124. VLA SITE c V2
+ 443118.48 -113848.48 25. NORTHERN CROSS d BO
+-4483311.64 2648815.92 -3671909.31 1 MOST e MO
+ 4324165.81 165927.11 4670132.83 1 Nancay f NC
+ 4033949.5 486989.4 4900430.8 1 Effelsberg g EF
+ 3822846.7600 -153802.2800 5086285.9000 1 JB_MKII h J2
+ 3828445.659 445223.600 5064921.5677 1 WSRT i WS
+-1668557.0 5506838.0 2744934.0 1 FAST k FA
+ 5109360.133 2006852.586 -3238948.127 1 MEERKAT m MK
+ 1656342.30 5797947.77 2073243.16 1 GMRT r GM
+-2826711.951 4679231.627 3274665.675 1 SHAO 65m XYZ s SH
+ 3826577.462 461022.624 5064892.526 1 LOFAR t LF
+-2559454.08 5095372.14 -2849057.18 1 MWA u MW
+ 5088964.00 301689.80 3825017.0 1 PICO VELETA v PV
+-1602206.58909 -5042244.28890 3554076.31847 1 LWA1 x LW
+-2058795.0 -3621559.0 4814280.0 1 CHIME (real) y CH
+ 4865182.7660 791922.6890 4035137.1740 1 SRT z SR
+-1531155.54418 -5045324.30517 3579583.89450 1 LWA-SV - LS
+ 883772.79740 -4924385.59750 3944042.49910 1 GB 20m XYZ - G2
\ No newline at end of file
diff --git a/tests/datafile/observatory/oldcodes.dat b/tests/datafile/observatory/oldcodes.dat
new file mode 100644
index 000000000..d54462629
--- /dev/null
+++ b/tests/datafile/observatory/oldcodes.dat
@@ -0,0 +1,22 @@
+# New format observatory information file.
+#
+# Values are x,y,z geocentric coordinates
+# The last column contains the old-style telescope ID code
+#
+####### From Jodrell obsys.dat file
+#
+ 383395.727 -173759.585 5077751.313 MKIII j
+ 3817176.557 -162921.170 5089462.046 TABLEY k
+ 3828714.504 -169458.987 5080647.749 DARNHALL l
+ 3859711.492 -201995.082 5056134.285 KNOCKIN m
+ 3923069.135 -146804.404 5009320.570 DEFFORD n
+ 0.0 1.0 0.0 COE o UTC
+#
+###### Telescope ID changed from the Jodrell obsys.dat file
+#
+ 3822473.365 -153692.318 5085851.303 JB_MKII h UTC
+ 3822294.825 -153862.275 5085987.071 JB_42ft i UTC
+
+# to handle Kaspi, Ryba & Taylor data, which is referred to UTC
+ 2390490.0 -5564764.0 1994727.0 ARECIBO aoutc UTC
+
diff --git a/tests/datafile/observatory/tempo.aliases b/tests/datafile/observatory/tempo.aliases
new file mode 100644
index 000000000..cfa391c2e
--- /dev/null
+++ b/tests/datafile/observatory/tempo.aliases
@@ -0,0 +1,32 @@
+# This alias file tries to replicate obsys.dat from tempo
+# To use this file by default, set env variable
+# $TEMPO2_ALIAS=tempo
+#
+gbt 1 GB
+quabbin 2 QU
+ao 3 arecebo arecibo
+hobart 4 HO
+princeton 5 PR
+vla 6 VL c
+pks 7 PK
+jb 8 JB
+gb300 9 G3
+gb140 a G1
+gb853 b G8
+medicina d BO
+mo e
+ncy f NC
+eff g EF
+jbmk2 h J2
+wsrt i WS
+fast k FA
+meerkat m MK
+gmrt r GM
+shao s SH
+lofar t LF
+mwa u MW
+pv v
+lwa1 x LW
+chime y CH
+srt z SR
+lwasv - LS
diff --git a/tests/pinttestdata.py b/tests/pinttestdata.py
index b0db5ea34..a9777bf78 100644
--- a/tests/pinttestdata.py
+++ b/tests/pinttestdata.py
@@ -1,6 +1,4 @@
-# pinttestdata.py
-
-# import this to get the location of the datafiles for tests. This file
+# Import this to get the location of the datafiles for tests. This file
# must be kept in the appropriate location relative to the test data dir
# for this to work.
diff --git a/tests/simulate_FD_model.py b/tests/simulate_FD_model.py
index 6d2fada05..4ee18f614 100644
--- a/tests/simulate_FD_model.py
+++ b/tests/simulate_FD_model.py
@@ -1,4 +1,3 @@
-#!/usr/bin/python
import sys
import astropy.time as time
@@ -46,5 +45,5 @@ def add_FD_model(freq_range, FD, toas):
]
add_FD_model(freqrange, fdcoeff, t)
- outfile = timfile + ".pint_simulate"
+ outfile = f"{timfile}.pint_simulate"
t.write_TOA_file(outfile, format="TEMPO2")
diff --git a/tests/test_B1855.py b/tests/test_B1855.py
index 20a594c0d..e1248e755 100644
--- a/tests/test_B1855.py
+++ b/tests/test_B1855.py
@@ -1,7 +1,7 @@
"""Various tests to assess the performance of the B1855+09."""
import logging
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -14,11 +14,11 @@
from pinttestdata import datadir
-class TestB1855(unittest.TestCase):
+class TestB1855:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parfileB1855 = "B1855+09_NANOGrav_dfg+12_TAI_FB90.par"
cls.timB1855 = "B1855+09_NANOGrav_dfg+12.tim"
@@ -29,14 +29,14 @@ def setUpClass(cls):
except IOError:
pytest.skip("Unable to fetch ephemeris")
cls.modelB1855 = mb.get_model(cls.parfileB1855)
- logging.debug("%s" % cls.modelB1855.components)
- logging.debug("%s" % cls.modelB1855.params)
+ logging.debug(f"{cls.modelB1855.components}")
+ logging.debug(f"{cls.modelB1855.params}")
# tempo result
cls.ltres = np.genfromtxt(
- cls.parfileB1855 + ".tempo2_test", skip_header=1, unpack=True
+ f"{cls.parfileB1855}.tempo2_test", skip_header=1, unpack=True
)
- def test_B1855(self):
+ def test_b1855(self):
pint_resids_us = Residuals(
self.toasB1855, self.modelB1855, use_weighted_mean=False
).time_resids.to(u.s)
@@ -50,11 +50,11 @@ def test_derivative(self):
testp = tdu.get_derivative_params(self.modelB1855)
delay = self.modelB1855.delay(self.toasB1855)
for p in testp.keys():
- log.debug("Runing derivative for %s", "d_delay_d_" + p)
+ log.debug("Runing derivative for %s", f"d_delay_d_{p}")
ndf = self.modelB1855.d_phase_d_param_num(self.toasB1855, p, testp[p])
adf = self.modelB1855.d_phase_d_param(self.toasB1855, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -62,18 +62,13 @@ def test_derivative(self):
"Derivative test failed at d_delay_d_%s with max relative difference %lf"
% (p, np.nanmax(relative_diff).value)
)
- if p in ["SINI"]:
- tol = 0.7
- else:
- tol = 1e-3
+ tol = 0.7 if p in ["SINI"] else 1e-3
log.debug(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
continue
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_B1855_9yrs.py b/tests/test_B1855_9yrs.py
index 59549ad16..91a848895 100644
--- a/tests/test_B1855_9yrs.py
+++ b/tests/test_B1855_9yrs.py
@@ -1,7 +1,7 @@
"""Various tests to assess the performance of the B1855+09."""
import logging
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -13,11 +13,11 @@
from pinttestdata import datadir
-class TestB1855(unittest.TestCase):
+class TestB1855:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parfileB1855 = "B1855+09_NANOGrav_9yv1.gls.par"
cls.timB1855 = "B1855+09_NANOGrav_9yv1.tim"
@@ -27,10 +27,10 @@ def setUpClass(cls):
cls.modelB1855 = mb.get_model(cls.parfileB1855)
# tempo result
cls.ltres = np.genfromtxt(
- cls.parfileB1855 + ".tempo2_test", skip_header=1, unpack=True
+ f"{cls.parfileB1855}.tempo2_test", skip_header=1, unpack=True
)
- def test_B1855(self):
+ def test_b1855(self):
pint_resids_us = Residuals(
self.toasB1855, self.modelB1855, use_weighted_mean=False
).time_resids.to(u.s)
@@ -44,11 +44,11 @@ def test_derivative(self):
testp = tdu.get_derivative_params(self.modelB1855)
delay = self.modelB1855.delay(self.toasB1855)
for p in testp.keys():
- log.debug("Runing derivative for %s", "d_delay_d_" + p)
+ log.debug("Runing derivative for %s", f"d_delay_d_{p}")
ndf = self.modelB1855.d_phase_d_param_num(self.toasB1855, p, testp[p])
adf = self.modelB1855.d_phase_d_param(self.toasB1855, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -56,13 +56,12 @@ def test_derivative(self):
"Derivative test failed at d_delay_d_%s with max relative difference %lf"
% (p, np.nanmax(relative_diff).value)
)
- if p in ["SINI"]:
- tol = 0.7
- else:
- tol = 1e-3
+ tol = 0.7 if p in ["SINI"] else 1e-3
log.debug(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
diff --git a/tests/test_B1953.py b/tests/test_B1953.py
index 9946b5b1b..6b30077e7 100644
--- a/tests/test_B1953.py
+++ b/tests/test_B1953.py
@@ -1,7 +1,7 @@
"""Various tests to assess the performance of the B1953+29."""
from astropy import log
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -13,11 +13,11 @@
from pinttestdata import datadir
-class TestB1953(unittest.TestCase):
+class TestB1953:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parfileB1953 = "B1953+29_NANOGrav_dfg+12_TAI_FB90.par"
cls.timB1953 = "B1953+29_NANOGrav_dfg+12.tim"
@@ -27,18 +27,18 @@ def setUpClass(cls):
cls.modelB1953 = mb.get_model(cls.parfileB1953)
# tempo result
cls.ltres, cls.ltbindelay = np.genfromtxt(
- cls.parfileB1953 + ".tempo2_test", skip_header=1, unpack=True
+ f"{cls.parfileB1953}.tempo2_test", skip_header=1, unpack=True
)
print(cls.ltres)
- def test_B1953_binary_delay(self):
+ def test_b1953_binary_delay(self):
# Calculate delays with PINT
pint_binary_delay = self.modelB1953.binarymodel_delay(self.toasB1953, None)
assert np.all(
np.abs(pint_binary_delay.value + self.ltbindelay) < 1e-8
), "B1953 binary delay test failed."
- def test_B1953(self):
+ def test_b1953(self):
pint_resids_us = Residuals(
self.toasB1953, self.modelB1953, use_weighted_mean=False
).time_resids.to(u.s)
@@ -52,11 +52,11 @@ def test_derivative(self):
testp = tdu.get_derivative_params(self.modelB1953)
delay = self.modelB1953.delay(self.toasB1953)
for p in testp.keys():
- log.debug("Runing derivative for %s".format("d_delay_d_" + p))
+ log.debug("Runing derivative for %s".format(f"d_delay_d_{p}"))
ndf = self.modelB1953.d_phase_d_param_num(self.toasB1953, p, testp[p])
adf = self.modelB1953.d_phase_d_param(self.toasB1953, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -71,13 +71,11 @@ def test_derivative(self):
else:
tol = 1e-3
log.debug(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
continue
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_Galactic.py b/tests/test_Galactic.py
index f048f5783..dc2e6358c 100644
--- a/tests/test_Galactic.py
+++ b/tests/test_Galactic.py
@@ -1,6 +1,6 @@
import logging
import os
-import unittest
+import pytest
import astropy.units as u
@@ -12,11 +12,11 @@
import astropy.time
-class TestGalactic(unittest.TestCase):
+class TestGalactic:
"""Test conversion from equatorial/ecliptic -> Galactic coordinates as astropy objects"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# J0613 is in equatorial
cls.parfileJ0613 = os.path.join(
datadir, "J0613-0200_NANOGrav_dfg+12_TAI_FB90.par"
@@ -177,7 +177,3 @@ def test_ecliptic_to_galactic(self):
% sep.arcsec
)
assert sep < 1e-9 * u.arcsec, msg
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_J0613.py b/tests/test_J0613.py
index 2181aeb41..1178afc7f 100644
--- a/tests/test_J0613.py
+++ b/tests/test_J0613.py
@@ -1,7 +1,7 @@
"""Various tests to assess the performance of the J0623-0200."""
import logging
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -13,11 +13,11 @@
from pinttestdata import datadir
-class TestJ0613(unittest.TestCase):
+class TestJ0613:
"""Compare delays from the ELL1 model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parfileJ0613 = "J0613-0200_NANOGrav_dfg+12_TAI_FB90.par"
cls.timJ0613 = "J0613-0200_NANOGrav_dfg+12.tim"
@@ -27,18 +27,18 @@ def setUpClass(cls):
cls.modelJ0613 = mb.get_model(cls.parfileJ0613)
# tempo result
cls.ltres, cls.ltbindelay = np.genfromtxt(
- cls.parfileJ0613 + ".tempo2_test", skip_header=1, unpack=True
+ f"{cls.parfileJ0613}.tempo2_test", skip_header=1, unpack=True
)
print(cls.ltres)
- def test_J0613_binary_delay(self):
+ def test_j0613_binary_delay(self):
# Calculate delays with PINT
pint_binary_delay = self.modelJ0613.binarymodel_delay(self.toasJ0613, None)
assert np.all(
np.abs(pint_binary_delay.value + self.ltbindelay) < 1e-8
), "J0613 binary delay test failed."
- def test_J0613(self):
+ def test_j0613(self):
pint_resids_us = Residuals(
self.toasJ0613, self.modelJ0613, use_weighted_mean=False
).time_resids.to(u.s)
@@ -61,11 +61,11 @@ def test_derivative(self):
testp["PMDEC"] = 1
testp["PMRA"] = 1
for p in testp.keys():
- log.debug("Runing derivative for %s", "d_delay_d_" + p)
+ log.debug("Runing derivative for %s", f"d_delay_d_{p}")
ndf = self.modelJ0613.d_phase_d_param_num(self.toasJ0613, p, testp[p])
adf = self.modelJ0613.d_phase_d_param(self.toasJ0613, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -73,18 +73,13 @@ def test_derivative(self):
"Derivative test failed at d_delay_d_%s with max relative difference %lf"
% (p, np.nanmax(relative_diff).value)
)
- if p in ["EPS1DOT", "EPS1"]:
- tol = 0.05
- else:
- tol = 1e-3
+ tol = 0.05 if p in ["EPS1DOT", "EPS1"] else 1e-3
log.debug(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
continue
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_TDB_method.py b/tests/test_TDB_method.py
index 927ae72d8..a25670a2d 100644
--- a/tests/test_TDB_method.py
+++ b/tests/test_TDB_method.py
@@ -1,6 +1,6 @@
"""tests for different compute TDB method."""
import os
-import unittest
+import pytest
import numpy as np
@@ -8,11 +8,11 @@
from pinttestdata import datadir
-class TestTDBMethod(unittest.TestCase):
+class TestTDBMethod:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.tim = "B1855+09_NANOGrav_9yv1.tim"
diff --git a/tests/test_absphase.py b/tests/test_absphase.py
index d6d3d161a..f37615f08 100644
--- a/tests/test_absphase.py
+++ b/tests/test_absphase.py
@@ -1,6 +1,6 @@
-#!/usr/bin/python
+import pytest
import os
-import unittest
+import pytest
import pint.models
import pint.toa
@@ -10,7 +10,7 @@
timfile = os.path.join(datadir, "zerophase.tim")
-class TestAbsPhase(unittest.TestCase):
+class TestAbsPhase:
def test_phase_zero(self):
# Check that model phase is 0.0 for a TOA at exactly the TZRMJD
model = pint.models.get_model(parfile)
@@ -18,5 +18,13 @@ def test_phase_zero(self):
ph = model.phase(toas, abs_phase=True)
# Check that integer and fractional phase values are very close to 0.0
- self.assertAlmostEqual(ph.int.value, 0.0)
- self.assertAlmostEqual(ph.frac.value, 0.0)
+ assert ph.int.value == pytest.approx(0.0)
+ assert ph.frac.value == pytest.approx(0.0)
+
+
+def test_tzr_attr():
+ model = pint.models.get_model(parfile)
+ toas = pint.toa.get_TOAs(timfile)
+
+ assert not toas.tzr
+ assert model.components["AbsPhase"].get_TZR_toa(toas).tzr
diff --git a/tests/test_all_component_and_model_builder.py b/tests/test_all_component_and_model_builder.py
index ceb8d09ae..7040d795b 100644
--- a/tests/test_all_component_and_model_builder.py
+++ b/tests/test_all_component_and_model_builder.py
@@ -6,7 +6,6 @@
import copy
from os.path import basename, join
import numpy as np
-import astropy.units as u
from pint.models.timing_model import (
TimingModel,
PhaseComponent,
@@ -74,7 +73,7 @@ def simple_model_alias_overlap():
@pytest.fixture
def test_timing_model():
ac = AllComponents()
- timing_model = TimingModel(
+ return TimingModel(
name="Test",
components=[
ac.components["AstrometryEquatorial"],
@@ -84,7 +83,6 @@ def test_timing_model():
ac.components["ScaleToaError"],
],
)
- return timing_model
pint_dict_base = {
@@ -115,9 +113,9 @@ def test_model_builder_class():
def test_aliases_mapping():
- """Test if aliases gets mapped correclty"""
+ """Test if aliases gets mapped correctly"""
mb = AllComponents()
- # all alases should be mapped to the components
+ # all aliases should be mapped to the components
assert set(mb._param_alias_map.keys()) == set(mb.param_component_map.keys())
# Test if the param_alias_map is passed by pointer
@@ -126,10 +124,10 @@ def test_aliases_mapping():
# assert "TESTAX" in mb._param_alias_map
# Test existing entry
# When adding an existing alias to the map. The mapped value should be the
- # same, otherwrise it will fail.
+ # same, otherwise it will fail.
mb._check_alias_conflict("F0", "F0", mb._param_alias_map)
# assert mb._param_alias_map["F0"] == "F0"
- # Test repeatable_params with differnt indices.
+ # Test repeatable_params with different indices.
for rp in mb.repeatable_param:
pint_par, first_init_par = mb.alias_to_pint_param(rp)
cp = mb.param_component_map[pint_par][0]
@@ -139,10 +137,10 @@ def test_aliases_mapping():
except PrefixError:
prefix = rp
- new_idx_par = prefix + "2"
- assert mb.alias_to_pint_param(new_idx_par)[0] == pint_par_obj.prefix + "2"
- new_idx_par = prefix + "55"
- assert mb.alias_to_pint_param(new_idx_par)[0] == pint_par_obj.prefix + "55"
+ new_idx_par = f"{prefix}2"
+ assert mb.alias_to_pint_param(new_idx_par)[0] == f"{pint_par_obj.prefix}2"
+ new_idx_par = f"{prefix}55"
+ assert mb.alias_to_pint_param(new_idx_par)[0] == f"{pint_par_obj.prefix}55"
# Test all aliases
for als in pint_par_obj.aliases:
assert mb.alias_to_pint_param(als)[0] == pint_par_obj.name
@@ -150,14 +148,14 @@ def test_aliases_mapping():
als_prefix, id, ids = split_prefixed_name(als)
except PrefixError:
als_prefix = als
- assert mb.alias_to_pint_param(als_prefix + "2")[0] == pint_par_obj.prefix + "2"
+ assert mb.alias_to_pint_param(f"{als_prefix}2")[0] == f"{pint_par_obj.prefix}2"
assert (
- mb.alias_to_pint_param(als_prefix + "55")[0] == pint_par_obj.prefix + "55"
+ mb.alias_to_pint_param(f"{als_prefix}55")[0] == f"{pint_par_obj.prefix}55"
)
def test_conflict_alias():
- """Test if model builder detects the alais conflict."""
+ """Test if model builder detects the alias conflict."""
mb = AllComponents()
# Test conflict parameter alias name
with pytest.raises(AliasConflict):
@@ -165,7 +163,7 @@ def test_conflict_alias():
def test_conflict_alias_in_component():
- # Define conflict alais from component class
+ # Define conflict alias from component class
class SimpleModel2(PhaseComponent):
"""Very simple test model component"""
@@ -190,7 +188,7 @@ def test_overlap_component(simple_model_overlap, simple_model_alias_overlap):
# Test overlap
overlap = mb._get_component_param_overlap(simple_model_overlap)
assert "Spindown" in overlap.keys()
- assert overlap["Spindown"][0] == set(["F0"])
+ assert overlap["Spindown"][0] == {"F0"}
# Only one over lap parameter F0
# Since the _get_component_param_overlap returns non-overlap part,
# we test if the non-overlap number makes sense.
@@ -201,13 +199,13 @@ def test_overlap_component(simple_model_overlap, simple_model_alias_overlap):
)
a_overlap = mb._get_component_param_overlap(simple_model_alias_overlap)
- assert a_overlap["Spindown"][0] == set(["F0"])
+ assert a_overlap["Spindown"][0] == {"F0"}
assert a_overlap["Spindown"][1] == len(simple_model_alias_overlap.params) - 1
assert (
a_overlap["Spindown"][2]
== len(mb.all_components.components["Spindown"].params) - 1
)
- assert a_overlap["AstrometryEcliptic"][0] == set(["ELONG"])
+ assert a_overlap["AstrometryEcliptic"][0] == {"ELONG"}
assert (
a_overlap["AstrometryEcliptic"][1] == len(simple_model_alias_overlap.params) - 1
)
@@ -411,3 +409,25 @@ def test_all_parfiles(parfile):
if basename(parfile) in bad_trouble:
pytest.skip("This parfile is unclear")
model = get_model(parfile)
+
+
+def test_include_solar_system_shapiro():
+ par = "F0 100 1"
+ m = get_model(io.StringIO(par))
+ assert "SolarSystemShapiro" not in m.components
+
+ par = """
+ ELAT 0.1 1
+ ELONG 2.1 1
+ F0 100 1
+ """
+ m = get_model(io.StringIO(par))
+ assert "SolarSystemShapiro" in m.components
+
+ par = """
+ RAJ 06:00:00 1
+ DECJ 12:00:00 1
+ F0 100 1
+ """
+ m = get_model(io.StringIO(par))
+ assert "SolarSystemShapiro" in m.components
diff --git a/tests/test_astropyobservatory.py b/tests/test_astropy_observatory.py
similarity index 93%
rename from tests/test_astropyobservatory.py
rename to tests/test_astropy_observatory.py
index bbcb05f9e..eed66d4fc 100644
--- a/tests/test_astropyobservatory.py
+++ b/tests/test_astropy_observatory.py
@@ -1,15 +1,16 @@
+import pytest
import logging
-import unittest
+import pytest
import numpy as np
import pint.observatory
-class TestAstropyObservatory(unittest.TestCase):
+class TestAstropyObservatory:
"""
Test fallback from PINT observatories to astropy observatories."""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# name and ITRF of an observatory that PINT should know about
cls.pint_obsname = "gbt"
cls.pint_itrf = [882589.65, -4924872.32, 3943729.348]
@@ -70,4 +71,4 @@ def test_missing_observatory(self):
"""
try to instantiate a missing observatory.
"""
- self.assertRaises(KeyError, pint.observatory.Observatory.get, self.none_obsname)
+ pytest.raises(KeyError, pint.observatory.Observatory.get, self.none_obsname)
diff --git a/tests/test_astropy_times.py b/tests/test_astropy_times.py
index d21b5c307..0f3e5cabd 100644
--- a/tests/test_astropy_times.py
+++ b/tests/test_astropy_times.py
@@ -1,4 +1,4 @@
-import unittest
+import pytest
import astropy.coordinates
import astropy.time
@@ -6,7 +6,7 @@
from astropy.utils.iers import IERS_A, IERS_A_URL
-class TestAstroPyTime(unittest.TestCase):
+class TestAstroPyTime:
"""This class contains a sequence of time conversion tests.
From the SOFA manual, these times are all equivalent:
@@ -19,7 +19,7 @@ class TestAstroPyTime(unittest.TestCase):
TCB 2006/01/15 21:25:56.893952
"""
- def setUp(self):
+ def setup_method(self):
self.lat = 19.48125
self.lon = -155.933222
earthloc = astropy.coordinates.EarthLocation.from_geodetic(
@@ -38,7 +38,7 @@ def test_utc(self):
y = "2006-01-15 21:24:37.500000"
assert x == y
- @unittest.skip
+ @pytest.mark.skip
def test_utc_ut1(self):
x = self.t.ut1.iso
y = "2006-01-15 21:24:37.834078"
@@ -70,7 +70,7 @@ def test_tt_tcb(self):
y = "2006-01-15 21:25:56.893952"
assert x == y
- @unittest.skip
+ @pytest.mark.skip
def test_iers_a_now():
# FIXME: might use cached IERS_A_URL?
# FIXME: what would this actually be testing?
diff --git a/tests/test_astropyversion.py b/tests/test_astropy_version.py
similarity index 100%
rename from tests/test_astropyversion.py
rename to tests/test_astropy_version.py
diff --git a/tests/test_barytoa.py b/tests/test_barytoa.py
index 699c2f470..e7f309969 100644
--- a/tests/test_barytoa.py
+++ b/tests/test_barytoa.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
import os
import astropy.units as u
diff --git a/tests/test_bayesian.py b/tests/test_bayesian.py
index c02b1898d..fdeff5e7c 100644
--- a/tests/test_bayesian.py
+++ b/tests/test_bayesian.py
@@ -43,8 +43,7 @@ def data_NGC6440E_efac():
timfile = examplefile("NGC6440E.tim")
model, toas = get_model_and_toas(parfile, timfile)
- parfile = str(model)
- parfile += "EFAC TEL gbt 1 1"
+ parfile = f"{str(model)}EFAC TEL gbt 1 1"
model = get_model(io.StringIO(parfile))
set_dummy_priors(model)
@@ -67,9 +66,11 @@ def data_NGC6440E_red():
parfile = examplefile("NGC6440E.par.good")
timfile = examplefile("NGC6440E.tim")
model, toas = get_model_and_toas(parfile, timfile)
- parfile = str(model)
- parfile += """RNAMP 1e-6
- RNIDX -3.82398"""
+ parfile = (
+ str(model)
+ + """RNAMP 1e-6
+ RNIDX -3.82398"""
+ )
model = get_model(io.StringIO(parfile))
set_dummy_priors(model)
diff --git a/tests/test_binary_generic.py b/tests/test_binary_generic.py
index c2c276aa6..dd7d5e352 100644
--- a/tests/test_binary_generic.py
+++ b/tests/test_binary_generic.py
@@ -20,5 +20,5 @@ def test_if_stand_alone_binary_model_get_updated_from_PINT_model(parfile):
try:
m = get_model(parfile)
except (ValueError, IOError, MissingParameter) as e:
- pytest.skip("Existing code raised an exception {}".format(e))
+ pytest.skip(f"Existing code raised an exception {e}")
verify_stand_alone_binary_parameter_updates(m)
diff --git a/tests/test_binconvert.py b/tests/test_binconvert.py
new file mode 100644
index 000000000..380c69591
--- /dev/null
+++ b/tests/test_binconvert.py
@@ -0,0 +1,327 @@
+from astropy import units as u, constants as c
+import astropy.time
+import numpy as np
+import io
+import os
+import copy
+import pytest
+
+from pint.models import get_model
+from pint import derived_quantities
+import pint.simulation
+import pint.fitter
+import pint.binaryconvert
+
+parDD = """
+PSRJ 1855+09
+RAJ 18:57:36.3932884 0 0.00002602730280675029
+DECJ +09:43:17.29196 0 0.00078789485676919773
+F0 186.49408156698235146 0 0.00000000000698911818
+F1 -6.2049547277487420583e-16 0 1.7380934373573401505e-20
+PEPOCH 49453
+POSEPOCH 49453
+DMEPOCH 49453
+DM 13.29709
+PMRA -2.5054345161030380639 0 0.03104958261053317181
+PMDEC -5.4974558631993817232 0 0.06348008663748286318
+PX 1.2288569063263405232 0 0.21243361289239687251
+T0 49452.940695077335647 0 0.00169031830532837251
+OM 276.55142180589701234 0 0.04936551005019605698
+ECC 0.1 0 0.00000004027191312623
+START 53358.726464889485214
+FINISH 55108.922917417192366
+TZRMJD 54177.508359343262555
+TZRFRQ 424
+TZRSITE ao
+TRES 0.395
+EPHVER 5
+CLK TT(TAI)
+MODE 1
+UNITS TDB
+T2CMETHOD TEMPO
+#NE_SW 0.000
+CORRECT_TROPOSPHERE N
+EPHEM DE405
+NITS 1
+NTOA 702
+CHI2R 2.1896 637
+SOLARN0 00.00
+TIMEEPH FB90
+PLANET_SHAPIRO N
+EDOT 2e-10 1 2e-12
+"""
+Mp = 1.4 * u.Msun
+Mc = 1.1 * u.Msun
+i = 85 * u.deg
+PB = 0.5 * u.day
+A1 = derived_quantities.a1sini(Mp, Mc, PB, i)
+
+parELL1 = """PSR B1855+09
+LAMBDA 286.8634874826803 1 0.0000000103957
+BETA 32.3214851782886 1 0.0000000165796
+PMLAMBDA -3.2697 1 0.0079
+PMBETA -5.0683 1 0.0154
+PX 0.7135 1 0.1221
+POSEPOCH 55637.0000
+F0 186.4940812354533364 1 0.0000000000087885
+F1 -6.204846776906D-16 1 4.557200069514D-20
+PEPOCH 55637.000000
+START 53358.726
+FINISH 57915.276
+DM 13.313704
+OLARN0 0.00
+EPHEM DE436
+ECL IERS2010
+CLK TT(BIPM2017)
+UNITS TDB
+TIMEEPH FB90
+T2CMETHOD TEMPO
+CORRECT_TROPOSPHERE N
+PLANET_SHAPIRO N
+DILATEFREQ N
+NTOA 313
+TRES 2.44
+TZRMJD 55638.45920097834544
+TZRFRQ 1389.800
+TZRSITE AO
+MODE 1
+NITS 1
+DMDATA 1
+INFO -f
+BINARY ELL1
+A1 9.230780257 1 0.000000172
+PB 12.32717119177539 1 0.00000000014613
+TASC 55631.710921347 1 0.000000017
+EPS1 -0.0000215334 1 0.0000000194
+EPS2 0.0000024177 1 0.0000000127
+SINI 0.999185 1 0.000190
+M2 0.246769 1 0.009532
+EPS1DOT 1e-10 1 1e-11
+EPS2DOT -1e-10 1 1e-11
+"""
+
+parELL1FB0 = """PSR B1855+09
+LAMBDA 286.8634874826803 1 0.0000000103957
+BETA 32.3214851782886 1 0.0000000165796
+PMLAMBDA -3.2697 1 0.0079
+PMBETA -5.0683 1 0.0154
+PX 0.7135 1 0.1221
+POSEPOCH 55637.0000
+F0 186.4940812354533364 1 0.0000000000087885
+F1 -6.204846776906D-16 1 4.557200069514D-20
+PEPOCH 55637.000000
+START 53358.726
+FINISH 57915.276
+DM 13.313704
+OLARN0 0.00
+EPHEM DE436
+ECL IERS2010
+CLK TT(BIPM2017)
+UNITS TDB
+TIMEEPH FB90
+T2CMETHOD TEMPO
+CORRECT_TROPOSPHERE N
+PLANET_SHAPIRO N
+DILATEFREQ N
+NTOA 313
+TRES 2.44
+TZRMJD 55638.45920097834544
+TZRFRQ 1389.800
+TZRSITE AO
+MODE 1
+NITS 1
+DMDATA 1
+INFO -f
+BINARY ELL1
+A1 9.230780257 1 0.000000172
+#PB 12.32717119177539 1 0.00000000014613
+FB0 9.389075477264583e-07 1 1.1130092850564776e-17
+TASC 55631.710921347 1 0.000000017
+EPS1 -0.0000215334 1 0.0000000194
+EPS2 0.0000024177 1 0.0000000127
+SINI 0.999185 1 0.000190
+M2 0.246769 1 0.009532
+EPS1DOT 1e-10 1 1e-11
+EPS2DOT -1e-10 1 1e-11
+"""
+
+kwargs = {"ELL1H": {"NHARMS": 3, "useSTIGMA": True}, "DDK": {"KOM": 0 * u.deg}}
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_ELL1(output):
+ m = get_model(io.StringIO(parELL1))
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_ELL1_roundtrip(output):
+ m = get_model(io.StringIO(parELL1))
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+ mback = pint.binaryconvert.convert_binary(mout, "ELL1")
+ for p in m.params:
+ if output == "BT" and p in ["M2", "SINI"]:
+ # these are not in BT
+ continue
+ if getattr(m, p).value is None:
+ continue
+ if not isinstance(getattr(m, p).quantity, (str, bool, astropy.time.Time)):
+ assert np.isclose(
+ getattr(m, p).value, getattr(mback, p).value
+ ), f"{p}: {getattr(m, p).value} does not match {getattr(mback, p).value}"
+ if getattr(m, p).uncertainty is not None:
+ # some precision may be lost in uncertainty conversion
+ assert np.isclose(
+ getattr(m, p).uncertainty_value,
+ getattr(mback, p).uncertainty_value,
+ rtol=0.2,
+ ), f"{p} uncertainty: {getattr(m, p).uncertainty_value} does not match {getattr(mback, p).uncertainty_value}"
+ else:
+ assert (
+ getattr(m, p).value == getattr(mback, p).value
+ ), f"{p}: {getattr(m, p).value} does not match {getattr(mback, p).value}"
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_ELL1FB0(output):
+ m = get_model(io.StringIO(parELL1FB0))
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_ELL1_roundtripFB0(output):
+ m = get_model(io.StringIO(parELL1FB0))
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+ mback = pint.binaryconvert.convert_binary(mout, "ELL1")
+ for p in m.params:
+ if output == "BT" and p in ["M2", "SINI"]:
+ # these are not in BT
+ continue
+ if getattr(m, p).value is None:
+ continue
+ if not isinstance(getattr(m, p).quantity, (str, bool, astropy.time.Time)):
+ assert np.isclose(
+ getattr(m, p).value, getattr(mback, p).value
+ ), f"{p}: {getattr(m, p).value} does not match {getattr(mback, p).value}"
+ if getattr(m, p).uncertainty is not None:
+ # some precision may be lost in uncertainty conversion
+ assert np.isclose(
+ getattr(m, p).uncertainty_value,
+ getattr(mback, p).uncertainty_value,
+ rtol=0.2,
+ ), f"{p} uncertainty: {getattr(m, p).uncertainty_value} does not match {getattr(mback, p).uncertainty_value}"
+ else:
+ assert (
+ getattr(m, p).value == getattr(mback, p).value
+ ), f"{p}: {getattr(m, p).value} does not match {getattr(mback, p).value}"
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1k", "ELL1H", "DD", "BT", "DDS", "DDK"])
+def test_DD(output):
+ m = get_model(
+ io.StringIO(
+ f"{parDD}\nBINARY DD\nSINI {np.sin(i).value}\nA1 {A1.value}\nPB {PB.value}\nM2 {Mc.value}\n"
+ )
+ )
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_DD_roundtrip(output):
+ s = f"{parDD}\nBINARY DD\nSINI {np.sin(i).value} 1 0.01\nA1 {A1.value}\nPB {PB.value} 1 0.1\nM2 {Mc.value} 1 0.01\n"
+ if output not in ["ELL1", "ELL1H"]:
+ s += "OMDOT 1e-10 1 1e-12"
+
+ m = get_model(io.StringIO(s))
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+ mback = pint.binaryconvert.convert_binary(mout, "DD")
+ for p in m.params:
+ if output == "BT" and p in ["M2", "SINI"]:
+ # these are not in BT
+ continue
+ if getattr(m, p).value is None:
+ continue
+ # print(getattr(m, p), getattr(mback, p))
+ if not isinstance(getattr(m, p).quantity, (str, bool, astropy.time.Time)):
+ assert np.isclose(getattr(m, p).value, getattr(mback, p).value)
+ if getattr(m, p).uncertainty is not None:
+ # some precision may be lost in uncertainty conversion
+ if output in ["ELL1", "ELL1H", "ELL1k"] and p in ["ECC"]:
+ # we lose precision on ECC since it also contains a contribution from OM now
+ continue
+ if output == "ELL1H" and p == "M2":
+ # this also loses precision
+ continue
+ assert np.isclose(
+ getattr(m, p).uncertainty_value,
+ getattr(mback, p).uncertainty_value,
+ rtol=0.2,
+ )
+ else:
+ assert getattr(m, p).value == getattr(mback, p).value
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_DDGR(output):
+ m = get_model(
+ io.StringIO(
+ f"{parDD}\nBINARY DDGR\nA1 {A1.value} 0 0.01\nPB {PB.value} 0 0.02\nM2 {Mc.value} \nMTOT {(Mp+Mc).value}\n"
+ )
+ )
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+
+
+@pytest.mark.parametrize(
+ "output",
+ [
+ "ELL1",
+ "ELL1k",
+ "ELL1H",
+ "DD",
+ "BT",
+ "DDS",
+ "DDK",
+ ],
+)
+def test_DDFB0(output):
+ m = get_model(
+ io.StringIO(
+ f"{parDD}\nBINARY DD\nSINI {np.sin(i).value}\nA1 {A1.value}\nFB0 {(1/PB).to_value(u.Hz)}\nM2 {Mc.value}\n"
+ )
+ )
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+
+
+@pytest.mark.parametrize("output", ["ELL1", "ELL1H", "ELL1k", "DD", "BT", "DDS", "DDK"])
+def test_DDFB0_roundtrip(output):
+ s = f"{parDD}\nBINARY DD\nSINI {np.sin(i).value} 1 0.01\nA1 {A1.value}\nFB0 {(1/PB).to_value(u.Hz)} 1 0.1\nM2 {Mc.value} 1 0.01\n"
+ if output not in ["ELL1", "ELL1H"]:
+ s += "OMDOT 1e-10 1 1e-12"
+
+ m = get_model(io.StringIO(s))
+ mout = pint.binaryconvert.convert_binary(m, output, **kwargs.get(output, {}))
+ mback = pint.binaryconvert.convert_binary(mout, "DD")
+ for p in m.params:
+ if output == "BT" and p in ["M2", "SINI"]:
+ # these are not in BT
+ continue
+ if getattr(m, p).value is None:
+ continue
+ # print(getattr(m, p), getattr(mback, p))
+ if not isinstance(getattr(m, p).quantity, (str, bool, astropy.time.Time)):
+ assert np.isclose(getattr(m, p).value, getattr(mback, p).value)
+ if getattr(m, p).uncertainty is not None:
+ # some precision may be lost in uncertainty conversion
+ if output in ["ELL1", "ELL1H", "ELL1k"] and p in ["ECC"]:
+ # we lose precision on ECC since it also contains a contribution from OM now
+ continue
+ if output == "ELL1H" and p == "M2":
+ # this also loses precision
+ continue
+ assert np.isclose(
+ getattr(m, p).uncertainty_value,
+ getattr(mback, p).uncertainty_value,
+ rtol=0.2,
+ )
+ else:
+ assert getattr(m, p).value == getattr(mback, p).value
diff --git a/tests/test_clockcorr.py b/tests/test_clockcorr.py
index 82aea76e2..ada505802 100644
--- a/tests/test_clockcorr.py
+++ b/tests/test_clockcorr.py
@@ -1,5 +1,6 @@
-import unittest
+import pytest
from os import path
+import io
import astropy.units as u
import numpy
@@ -9,9 +10,10 @@
from pint.observatory import Observatory
from pint.observatory.clock_file import ClockFile
+from pint.toa import get_TOAs
-class TestClockcorrection(unittest.TestCase):
+class TestClockcorrection:
# Note, these tests currently depend on external data (TEMPO2 clock
# files, which could potentially change. Values here are taken
# from tempo2 version 2020-06-01 or so.
@@ -45,3 +47,23 @@ def test_wsrt_parsed_correctly_with_text_columns(self):
with pytest.raises(RuntimeError):
t = cf.time[-1] + 1.0 * u.d
cf.evaluate(t, limits="error")
+
+
+def test_clockcorr_roundtrip():
+ timlines = """FORMAT 1
+ toa1 1400 55555.0 1.0 gbt
+ toa2 1400 55556.0 1.0 gbt"""
+ t = get_TOAs(io.StringIO(timlines))
+ # should have positive clock correction applied
+ assert t.get_mjds()[0].value > 55555
+ assert t.get_mjds()[1].value > 55556
+ o = io.StringIO()
+ t.write_TOA_file(o)
+ o.seek(0)
+ lines = o.readlines()
+ # make sure the clock corrections are no longer there.
+ for line in lines:
+ if line.startswith("toa1"):
+ assert float(line.split()[2]) == 55555
+ if line.startswith("toa2"):
+ assert float(line.split()[2]) == 55556
diff --git a/tests/test_compare.py b/tests/test_compare.py
index 2252e0ad2..ce5f3d488 100644
--- a/tests/test_compare.py
+++ b/tests/test_compare.py
@@ -1,4 +1,4 @@
-import unittest
+import pytest
import astropy
import pint
import pint.models as mod
@@ -9,7 +9,7 @@
from pinttestdata import datadir
-class TestCompare(unittest.TestCase):
+class TestCompare:
"""Test model comparison method"""
def test_paramchange(self):
@@ -47,21 +47,19 @@ def test_paramchange(self):
param_cp.quantity = (
param_cp.quantity + factor * param_cp.uncertainty
)
+ elif isinstance(param, pint.models.parameter.boolParameter):
+ param.value = not param.value
+ elif isinstance(param, pint.models.parameter.intParameter):
+ param.value += 1
+ elif param_cp.quantity != 0:
+ param_cp.quantity = 1.1 * param_cp.quantity
else:
- if isinstance(param, pint.models.parameter.boolParameter):
- param.value = not param.value
- elif isinstance(param, pint.models.parameter.intParameter):
- param.value += 1
- elif param_cp.quantity != 0:
- param_cp.quantity = 1.1 * param_cp.quantity
- else:
- param_cp.value += 3.0
+ param_cp.value += 3.0
model.compare(
modelcp, threshold_sigma=threshold_sigma, verbosity=verbosity
)
if not accumulate_changes:
modelcp = cp(model)
- assert True, "Failure in parameter changing test"
def test_uncertaintychange(self):
# This test changes each parameter's uncertainty by the "factor" below.
@@ -93,9 +91,7 @@ def test_uncertaintychange(self):
or param_cp.quantity is None
):
continue
- if param_cp.uncertainty != None:
- param_cp.uncertainty = factor * param_cp.uncertainty
- else:
+ if param_cp.uncertainty is None:
if isinstance(param, pint.models.parameter.boolParameter):
param.value = not param.value
elif isinstance(param, pint.models.parameter.intParameter):
@@ -103,10 +99,11 @@ def test_uncertaintychange(self):
else:
param.uncertainty = 0 * param.units
param_cp.uncertainty = 3.0 * param_cp.units
+ else:
+ param_cp.uncertainty = factor * param_cp.uncertainty
model.compare(modelcp, threshold_sigma=3.0, verbosity=verbosity)
if not accumulate_changes:
modelcp = cp(model)
- assert True, "Failure in uncertainty changing test"
def test_missing_uncertainties(self):
# Removes uncertainties from both models and attempts to use compare.
@@ -161,4 +158,3 @@ def test_missing_uncertainties(self):
model_2.compare(model_1)
model_1 = mod.get_model(io.StringIO(par_base1))
model_2 = mod.get_model(io.StringIO(par_base2))
- assert True, "Failure in missing uncertainty test"
diff --git a/tests/test_copy.py b/tests/test_copy.py
index 43ffd804e..a2f180d91 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -1,5 +1,4 @@
-""" Test for pint object copying
-"""
+""" Test for PINT object copying"""
import os
import pytest
diff --git a/tests/test_covariance_matrix.py b/tests/test_covariance_matrix.py
index 81aab55f9..c8ba789ff 100644
--- a/tests/test_covariance_matrix.py
+++ b/tests/test_covariance_matrix.py
@@ -1,5 +1,4 @@
-""" Various of tests for the pint covariance.
-"""
+""" Various of tests for the pint covariance."""
import pytest
import os
@@ -16,7 +15,7 @@
class TestCovarianceMatrix:
"""Test for covariance matrix"""
- def setup(self):
+ def setup_method(self):
self.matrix1 = np.arange(16).reshape((4, 4))
self.label1 = [{"c": (0, 4, u.s)}] * 2
self.matrix2 = np.arange(9).reshape((3, 3))
diff --git a/tests/test_d_phase_d_toa.py b/tests/test_d_phase_d_toa.py
index 66e14c7a6..179b162a0 100644
--- a/tests/test_d_phase_d_toa.py
+++ b/tests/test_d_phase_d_toa.py
@@ -1,5 +1,5 @@
import os
-import unittest
+import pytest
import numpy as np
@@ -11,20 +11,20 @@
from pinttestdata import datadir, testdir
-class TestD_phase_d_toa(unittest.TestCase):
+class TestD_phase_d_toa:
@classmethod
- def setUpClass(self):
+ def setup_class(cls):
os.chdir(datadir)
- self.parfileB1855 = "B1855+09_polycos.par"
- self.timB1855 = "B1855_polyco.tim"
- self.toasB1855 = toa.get_TOAs(
- self.timB1855, ephem="DE405", planets=False, include_bipm=False
+ cls.parfileB1855 = "B1855+09_polycos.par"
+ cls.timB1855 = "B1855_polyco.tim"
+ cls.toasB1855 = toa.get_TOAs(
+ cls.timB1855, ephem="DE405", planets=False, include_bipm=False
)
- self.modelB1855 = mb.get_model(self.parfileB1855)
+ cls.modelB1855 = mb.get_model(cls.parfileB1855)
# Read tempo style polycos.
- self.plc = Polycos().read("B1855_polyco.dat", "tempo")
+ cls.plc = Polycos().read("B1855_polyco.dat", "tempo")
- def testD_phase_d_toa(self):
+ def test_d_phase_d_toa(self):
pint_d_phase_d_toa = self.modelB1855.d_phase_d_toa(self.toasB1855)
mjd = np.array(
[
@@ -36,7 +36,3 @@ def testD_phase_d_toa(self):
diff = pint_d_phase_d_toa.value - tempo_d_phase_d_toa
relative_diff = diff / tempo_d_phase_d_toa
assert np.all(np.abs(relative_diff) < 1e-7), "d_phase_d_toa test failed."
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_datafiles.py b/tests/test_datafiles.py
index 0ed5fe94f..6f85eb9f9 100644
--- a/tests/test_datafiles.py
+++ b/tests/test_datafiles.py
@@ -1,4 +1,5 @@
"""Test installation of PINT data files"""
+
import os
import pytest
import pint.config
diff --git a/tests/test_dd.py b/tests/test_dd.py
index aafa43c69..8bbce7f2d 100644
--- a/tests/test_dd.py
+++ b/tests/test_dd.py
@@ -1,6 +1,6 @@
"""Various tests to assess the performance of the DD model."""
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -12,11 +12,11 @@
from copy import deepcopy
-class TestDD(unittest.TestCase):
+class TestDD:
"""Compare delays from the dd model with libstempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parfileB1855 = "B1855+09_NANOGrav_dfg+12_modified_DD.par"
cls.timB1855 = "B1855+09_NANOGrav_dfg+12.tim"
@@ -29,7 +29,7 @@ def setUpClass(cls):
f"{cls.parfileB1855}.tempo_test", unpack=True
)
- def test_J1855_binary_delay(self):
+ def test_j1855_binary_delay(self):
# Calculate delays with PINT
pint_binary_delay = self.modelB1855.binarymodel_delay(self.toasB1855, None)
assert np.all(
@@ -37,7 +37,7 @@ def test_J1855_binary_delay(self):
), "DD B1855 TEST FAILED"
# TODO: PINT can still increase the precision by adding more components
- def test_B1855(self):
+ def test_b1855(self):
pint_resids_us = Residuals(
self.toasB1855, self.modelB1855, use_weighted_mean=False
).time_resids.to(u.s)
diff --git a/tests/test_ddgr.py b/tests/test_ddgr.py
new file mode 100644
index 000000000..d436a6f7b
--- /dev/null
+++ b/tests/test_ddgr.py
@@ -0,0 +1,324 @@
+from astropy import units as u, constants as c
+import numpy as np
+import io
+import os
+import copy
+
+from pint.models import get_model
+from pint import derived_quantities
+import pint.simulation
+import pint.fitter
+
+par = """
+PSRJ 1855+09
+RAJ 18:57:36.3932884 0 0.00002602730280675029
+DECJ +09:43:17.29196 0 0.00078789485676919773
+F0 186.49408156698235146 0 0.00000000000698911818
+F1 -6.2049547277487420583e-16 0 1.7380934373573401505e-20
+PEPOCH 49453
+POSEPOCH 49453
+DMEPOCH 49453
+DM 13.29709
+PMRA -2.5054345161030380639 0 0.03104958261053317181
+PMDEC -5.4974558631993817232 0 0.06348008663748286318
+PX 1.2288569063263405232 0 0.21243361289239687251
+T0 49452.940695077335647 0 0.00169031830532837251
+OM 276.55142180589701234 0 0.04936551005019605698
+ECC 0.1 0 0.00000004027191312623
+START 53358.726464889485214
+FINISH 55108.922917417192366
+TZRMJD 54177.508359343262555
+TZRFRQ 424
+TZRSITE ao
+TRES 0.395
+EPHVER 5
+CLK TT(TAI)
+MODE 1
+UNITS TDB
+T2CMETHOD TEMPO
+#NE_SW 0.000
+CORRECT_TROPOSPHERE N
+EPHEM DE405
+NITS 1
+NTOA 702
+CHI2R 2.1896 637
+SOLARN0 00.00
+TIMEEPH FB90
+PLANET_SHAPIRO N
+"""
+Mp = 1.4 * u.Msun
+Mc = 1.1 * u.Msun
+i = 75 * u.deg
+PB = 0.5 * u.day
+
+
+class TestDDGR:
+ def setup_method(self):
+ A1 = derived_quantities.a1sini(Mp, Mc, PB, i)
+ self.mDD = get_model(
+ io.StringIO(
+ f"{par}\nBINARY DD\nSINI {np.sin(i).value}\nA1 {A1.value}\nPB {PB.value}\nM2 {Mc.value}\n"
+ )
+ )
+ self.mDDGR = get_model(
+ io.StringIO(
+ f"{par}\nBINARY DDGR\nA1 {A1.value}\nPB {PB.value}\nM2 {Mc.value}\nMTOT {(Mp+Mc).value}\n"
+ )
+ )
+
+ def test_pkparameters(self):
+ pbdot = derived_quantities.pbdot(
+ Mp, Mc, self.mDD.PB.quantity, self.mDD.ECC.quantity
+ )
+ gamma = derived_quantities.gamma(
+ Mp, Mc, self.mDD.PB.quantity, self.mDD.ECC.quantity
+ )
+ omdot = derived_quantities.omdot(
+ Mp, Mc, self.mDD.PB.quantity, self.mDD.ECC.quantity
+ )
+ assert np.isclose(gamma, self.mDDGR.GAMMA.quantity)
+ assert np.isclose(pbdot, self.mDDGR.PBDOT.quantity)
+ assert np.isclose(omdot, self.mDDGR.OMDOT.quantity)
+
+ def test_binarydelay(self):
+ # set the PK parameters
+ self.mDD.GAMMA.value = self.mDDGR.GAMMA.value
+ self.mDD.PBDOT.value = self.mDDGR.PBDOT.value
+ self.mDD.OMDOT.value = self.mDDGR.OMDOT.value
+ self.mDD.DR.value = self.mDDGR.DR.value
+ self.mDD.DTH.value = self.mDDGR.DTH.value
+
+ t = pint.simulation.make_fake_toas_uniform(55000, 57000, 100, model=self.mDD)
+ DD_delay = self.mDD.binarymodel_delay(t, None)
+ DDGR_delay = self.mDDGR.binarymodel_delay(t, None)
+ assert np.allclose(DD_delay, DDGR_delay)
+
+ def test_xomdot(self):
+ self.mDD.GAMMA.value = self.mDDGR.GAMMA.value
+ self.mDD.PBDOT.value = self.mDDGR.PBDOT.value
+ self.mDD.OMDOT.value = self.mDDGR.OMDOT.value * 2
+ self.mDD.DR.value = self.mDDGR.DR.value
+ self.mDD.DTH.value = self.mDDGR.DTH.value
+
+ self.mDDGR.XOMDOT.value = self.mDDGR.OMDOT.value
+ t = pint.simulation.make_fake_toas_uniform(55000, 57000, 100, model=self.mDD)
+ DD_delay = self.mDD.binarymodel_delay(t, None)
+ DDGR_delay = self.mDDGR.binarymodel_delay(t, None)
+ assert np.allclose(DD_delay, DDGR_delay)
+
+ def test_xpbdot(self):
+ self.mDD.GAMMA.value = self.mDDGR.GAMMA.value
+ self.mDD.PBDOT.value = self.mDDGR.PBDOT.value * 2
+ self.mDD.OMDOT.value = self.mDDGR.OMDOT.value
+ self.mDD.DR.value = self.mDDGR.DR.value
+ self.mDD.DTH.value = self.mDDGR.DTH.value
+
+ self.mDDGR.XPBDOT.value = self.mDDGR.PBDOT.value
+ t = pint.simulation.make_fake_toas_uniform(55000, 57000, 100, model=self.mDD)
+ DD_delay = self.mDD.binarymodel_delay(t, None)
+ DDGR_delay = self.mDDGR.binarymodel_delay(t, None)
+ assert np.allclose(DD_delay, DDGR_delay)
+
+ def test_ddgrfit_noMTOT(self):
+ # set the PK parameters
+ self.mDD.GAMMA.value = self.mDDGR.GAMMA.value
+ self.mDD.PBDOT.value = self.mDDGR.PBDOT.value
+ self.mDD.OMDOT.value = self.mDDGR.OMDOT.value
+ self.mDD.DR.value = self.mDDGR.DR.value
+ self.mDD.DTH.value = self.mDDGR.DTH.value
+
+ t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, error=1 * u.us, add_noise=True, model=self.mDD
+ )
+
+ fDD = pint.fitter.Fitter.auto(t, self.mDD)
+ fDDGR = pint.fitter.Fitter.auto(t, self.mDDGR)
+ for p in ["ECC", "PB", "A1", "OM", "T0"]:
+ getattr(fDD.model, p).frozen = False
+ getattr(fDDGR.model, p).frozen = False
+
+ fDD.model.GAMMA.frozen = False
+ fDD.model.PBDOT.frozen = False
+ fDD.model.OMDOT.frozen = False
+ fDD.model.SINI.frozen = False
+ fDD.model.M2.frozen = False
+
+ # cannot fit for MTOT yet
+ fDDGR.model.M2.frozen = False
+ fDDGR.model.MTOT.frozen = True
+ fDD.fit_toas()
+ chi2DD = fDD.resids.calc_chi2()
+
+ fDDGR.fit_toas()
+ chi2DDGR = fDDGR.resids.calc_chi2()
+ M2 = copy.deepcopy(fDDGR.model.M2.quantity)
+ # chi^2 values don't have to be super close
+ assert (
+ np.fabs(fDD.model.M2.quantity - fDDGR.model.M2.quantity)
+ < 4 * fDD.model.M2.uncertainty
+ )
+ # perturn M2 and make sure chi^2 gets worse
+ fDDGR.model.M2.quantity += 3 * fDDGR.model.M2.uncertainty
+ fDDGR.resids.update()
+ assert fDDGR.resids.calc_chi2() > chi2DDGR
+ fDDGR.fit_toas()
+ assert np.isclose(fDDGR.resids.calc_chi2(), chi2DDGR, atol=0.1)
+ assert np.isclose(fDDGR.model.M2.quantity, M2)
+
+ def test_ddgrfit(self):
+ t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, model=self.mDDGR, error=1 * u.us, add_noise=True
+ )
+ fDDGR = pint.fitter.Fitter.auto(t, self.mDDGR)
+
+ fDDGR.model.M2.frozen = False
+ fDDGR.model.MTOT.frozen = False
+ # mDDGR.XOMDOT.frozen = False
+ fDDGR.model.XPBDOT.frozen = False
+
+ # start well away from best-fit
+ fDDGR.model.MTOT.quantity += 1e-4 * u.Msun
+ fDDGR.model.M2.quantity += 1e-2 * u.Msun
+ fDDGR.update_resids()
+
+ fDDGR.fit_toas()
+ assert (
+ np.abs(fDDGR.model.MTOT.quantity - (Mp + Mc))
+ < 4 * fDDGR.model.MTOT.uncertainty
+ )
+ assert np.abs(fDDGR.model.M2.quantity - (Mc)) < 4 * fDDGR.model.M2.uncertainty
+ assert np.abs(fDDGR.model.XPBDOT.quantity) < 4 * fDDGR.model.XPBDOT.uncertainty
+
+ def test_design_XOMDOT(self):
+ t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, model=self.mDDGR, error=1 * u.us, add_noise=True
+ )
+ f = pint.fitter.Fitter.auto(t, self.mDDGR)
+ for p in f.model.free_params:
+ getattr(f.model, p).frozen = True
+ f.model.XOMDOT.frozen = False
+ f.model.XOMDOT.value = 0
+ f.fit_toas()
+ XOMDOT = 0 * u.deg / u.yr
+ dXOMDOT = 1e-6 * u.deg / u.yr
+ # move away from minimum
+ f.model.XOMDOT.quantity = XOMDOT + dXOMDOT
+ f.update_resids()
+ M, pars, units = f.model.designmatrix(f.toas)
+ # this is recalculating chi^2 for comparison
+ chi2start = (
+ ((f.resids.calc_time_resids() / f.toas.get_errors()).decompose()) ** 2
+ ).sum()
+ chi2pred = (
+ (
+ (
+ (f.resids.calc_time_resids() - M[:, 1] * dXOMDOT.value / u.Hz)
+ / f.toas.get_errors()
+ ).decompose()
+ )
+ ** 2
+ ).sum()
+ f.model.XOMDOT.quantity = XOMDOT + dXOMDOT * 2
+ f.update_resids()
+ chi2found = f.resids.calc_chi2()
+ assert np.isclose(chi2pred, chi2found, rtol=1e-2)
+
+ def test_design_XPBDOT(self):
+ t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, model=self.mDDGR, error=1 * u.us, add_noise=True
+ )
+ f = pint.fitter.Fitter.auto(t, self.mDDGR)
+ for p in f.model.free_params:
+ getattr(f.model, p).frozen = True
+ f.model.XPBDOT.frozen = False
+ f.model.XPBDOT.value = 0
+ f.fit_toas()
+ XPBDOT = 0 * u.s / u.s
+ dXPBDOT = 1e-14 * u.s / u.s
+ # move away from minimum
+ f.model.XPBDOT.quantity = XPBDOT + dXPBDOT
+ f.update_resids()
+ M, pars, units = f.model.designmatrix(f.toas)
+ # this is recalculating chi^2 for comparison
+ chi2start = (
+ ((f.resids.calc_time_resids() / f.toas.get_errors()).decompose()) ** 2
+ ).sum()
+ chi2pred = (
+ (
+ (
+ (f.resids.calc_time_resids() - M[:, 1] * dXPBDOT.value / u.Hz)
+ / f.toas.get_errors()
+ ).decompose()
+ )
+ ** 2
+ ).sum()
+ f.model.XPBDOT.quantity = XPBDOT + dXPBDOT * 2
+ f.update_resids()
+ chi2found = f.resids.calc_chi2()
+ assert np.isclose(chi2pred, chi2found, rtol=1e-2)
+
+ def test_design_M2(self):
+ t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, model=self.mDDGR, error=1 * u.us, add_noise=True
+ )
+ f = pint.fitter.Fitter.auto(t, self.mDDGR)
+ for p in f.model.free_params:
+ getattr(f.model, p).frozen = True
+ f.model.M2.frozen = False
+ f.fit_toas()
+ M2 = f.model.M2.quantity
+ dM2 = 1e-4 * u.Msun
+ # move away from minimum
+ f.model.M2.quantity = M2 + dM2
+ f.update_resids()
+ M, pars, units = f.model.designmatrix(f.toas)
+ # this is recalculating chi^2 for comparison
+ chi2start = (
+ ((f.resids.calc_time_resids() / f.toas.get_errors()).decompose()) ** 2
+ ).sum()
+ chi2pred = (
+ (
+ (
+ (f.resids.calc_time_resids() - M[:, 1] * dM2.value / u.Hz)
+ / f.toas.get_errors()
+ ).decompose()
+ )
+ ** 2
+ ).sum()
+ f.model.M2.quantity = M2 + dM2 * 2
+ f.update_resids()
+ chi2found = f.resids.calc_chi2()
+ assert np.isclose(chi2pred, chi2found, rtol=1e-2)
+
+ def test_design_MTOT(self):
+ t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, model=self.mDDGR, error=1 * u.us, add_noise=True
+ )
+ f = pint.fitter.Fitter.auto(t, self.mDDGR)
+ for p in f.model.free_params:
+ getattr(f.model, p).frozen = True
+ f.model.MTOT.frozen = False
+ f.fit_toas()
+ MTOT = f.model.MTOT.quantity
+ dMTOT = 1e-5 * u.Msun
+ # move away from minimum
+ f.model.MTOT.quantity = MTOT + dMTOT
+ f.update_resids()
+ M, pars, units = f.model.designmatrix(f.toas)
+ # this is recalculating chi^2 for comparison
+ chi2start = (
+ ((f.resids.calc_time_resids() / f.toas.get_errors()).decompose()) ** 2
+ ).sum()
+ chi2pred = (
+ (
+ (
+ (f.resids.calc_time_resids() - M[:, 1] * dMTOT.value / u.Hz)
+ / f.toas.get_errors()
+ ).decompose()
+ )
+ ** 2
+ ).sum()
+ f.model.MTOT.quantity = MTOT + dMTOT * 2
+ f.update_resids()
+ chi2found = f.resids.calc_chi2()
+ assert np.isclose(chi2pred, chi2found, rtol=1e-2)
diff --git a/tests/test_ddk.py b/tests/test_ddk.py
index 4887e29e9..a176db29f 100644
--- a/tests/test_ddk.py
+++ b/tests/test_ddk.py
@@ -1,8 +1,9 @@
"""Various tests to assess the performance of the DD model."""
+
import copy
import logging
import os
-import unittest
+import pytest
from io import StringIO
import warnings
@@ -43,11 +44,11 @@
"""
-class TestDDK(unittest.TestCase):
+class TestDDK:
"""Compare delays from the ddk model with libstempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parfileJ1713 = "J1713+0747_NANOGrav_11yv0_short.gls.par"
cls.ICRSparfileJ1713 = "J1713+0747_NANOGrav_11yv0_short.gls.ICRS.par"
cls.timJ1713 = "J1713+0747_NANOGrav_11yv0_short.tim"
@@ -67,7 +68,7 @@ def setUpClass(cls):
cls.ECLltres,
cls.ECLltbindelay,
) = np.genfromtxt(
- os.path.join(datadir, cls.parfileJ1713 + ".libstempo"), unpack=True
+ os.path.join(datadir, f"{cls.parfileJ1713}.libstempo"), unpack=True
)[
:, index
]
@@ -78,12 +79,12 @@ def setUpClass(cls):
cls.ICRSltres,
cls.ICRSltbindelay,
) = np.genfromtxt(
- os.path.join(datadir, cls.ICRSparfileJ1713 + ".libstempo"), unpack=True
+ os.path.join(datadir, f"{cls.ICRSparfileJ1713}.libstempo"), unpack=True
)[
:, index
]
- def test_J1713_ECL_binary_delay(self):
+ def test_j1713_ecl_binary_delay(self):
# Calculate delays with PINT
pint_binary_delay = self.ECLmodelJ1713.binarymodel_delay(self.toasJ1713, None)
print(f"{np.abs(pint_binary_delay.value + self.ECLltbindelay).max()}")
@@ -92,7 +93,7 @@ def test_J1713_ECL_binary_delay(self):
% np.abs(pint_binary_delay.value + self.ECLltbindelay).max()
)
- def test_J1713_ICRS_binary_delay(self):
+ def test_j1713_icrs_binary_delay(self):
# Calculate delays with PINT
pint_binary_delay = self.ICRSmodelJ1713.binarymodel_delay(self.toasJ1713, None)
print(f"{np.abs(pint_binary_delay.value + self.ECLltbindelay).max()}")
@@ -101,7 +102,7 @@ def test_J1713_ICRS_binary_delay(self):
% np.abs(pint_binary_delay.value + self.ICRSltbindelay).max()
)
- def test_J1713_ECL(self):
+ def test_j1713_ecl(self):
pint_resids_us = Residuals(
self.toasJ1713, self.ECLmodelJ1713, use_weighted_mean=False
).time_resids.to(u.s)
@@ -112,7 +113,7 @@ def test_J1713_ECL(self):
% np.abs(diff - diff.mean()).max()
)
- def test_J1713_ICRS(self):
+ def test_j1713_icrs(self):
pint_resids_us = Residuals(
self.toasJ1713, self.ICRSmodelJ1713, use_weighted_mean=False
).time_resids.to(u.s)
@@ -142,7 +143,7 @@ def test_change_px(self):
diff = bdelay0 - bdelay1
assert np.all(diff != 0)
- def test_J1713_deriv(self):
+ def test_j1713_deriv(self):
testp = tdu.get_derivative_params(self.ECLmodelJ1713)
delay = self.ECLmodelJ1713.delay(self.toasJ1713)
for p in testp.keys():
@@ -154,11 +155,11 @@ def test_J1713_deriv(self):
par = getattr(self.ECLmodelJ1713, p)
if isinstance(par, boolParameter):
continue
- print("Runing derivative for %s" % ("d_phase_d_" + p))
+ print(f"Runing derivative for d_phase_d_{p}")
ndf = self.ECLmodelJ1713.d_phase_d_param_num(self.toasJ1713, p, testp[p])
adf = self.ECLmodelJ1713.d_phase_d_param(self.toasJ1713, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -173,14 +174,16 @@ def test_J1713_deriv(self):
else:
tol = 1e-3
print(
- "derivative relative diff for %s, %lf"
- % ("d_phase_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_phase_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
continue
- def test_K96(self):
+ def test_k96(self):
modelJ1713 = copy.deepcopy(self.ECLmodelJ1713)
log = logging.getLogger("TestJ1713 Switch of K96")
modelJ1713.K96.value = False
@@ -285,7 +288,3 @@ def test_A1dot_warning():
def test_alternative_solutions():
mECL = get_model(StringIO(temp_par_str + "\n KIN 71.969 1 0.562"))
assert len(mECL.components["BinaryDDK"].alternative_solutions()) == 4
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_dds.py b/tests/test_dds.py
new file mode 100644
index 000000000..d9c8ffdad
--- /dev/null
+++ b/tests/test_dds.py
@@ -0,0 +1,130 @@
+import os
+import pytest
+import io
+import copy
+
+import astropy.units as u
+import numpy as np
+
+from pint.models import get_model, get_model_and_toas
+import pint.simulation
+import pint.fitter
+import pint.toa as toa
+from pint import binaryconvert
+from pinttestdata import datadir
+
+
+def test_DDS_delay():
+ """Make a copy of a DD model and switch to DDS"""
+ parfileB1855 = os.path.join(datadir, "B1855+09_NANOGrav_dfg+12_modified.par")
+ timB1855 = os.path.join(datadir, "B1855+09_NANOGrav_dfg+12.tim")
+ t = toa.get_TOAs(timB1855, ephem="DE405", planets=False, include_bipm=False)
+ mDD = get_model(parfileB1855)
+ with open(parfileB1855) as f:
+ lines = f.readlines()
+ outlines = ""
+ for line in lines:
+ if not (line.startswith("SINI") or line.startswith("BINARY")):
+ outlines += f"{line}"
+ elif line.startswith("SINI"):
+ d = line.split()
+ sini = float(d[1])
+ shapmax = -np.log(1 - sini)
+ outlines += f"SHAPMAX {shapmax}\n"
+ elif line.startswith("BINARY"):
+ outlines += "BINARY DDS\n"
+ mDDS = get_model(io.StringIO(outlines))
+ DD_delay = mDD.binarymodel_delay(t, None)
+ DDS_delay = mDDS.binarymodel_delay(t, None)
+ assert np.allclose(DD_delay, DDS_delay)
+
+
+class TestDDSFit:
+ def setup_method(self):
+ par = """
+ PSR J1640+2224
+ EPHEM DE440
+ CLOCK TT(BIPM2019)
+ UNITS TDB
+ START 53420.4178610467357292
+ FINISH 59070.9677620557378125
+ INFO -f
+ TIMEEPH FB90
+ T2CMETHOD IAU2000B
+ DILATEFREQ N
+ DMDATA N
+ NTOA 14035
+ CHI2 13829.35865041346
+ ELONG 243.989094207621264 0 0.00000001050784977998
+ ELAT 44.058509693269293 0 0.00000001093755261092
+ PMELONG 4.195360836354164 0 0.005558385201250641
+ PMELAT -10.720376071408817 0 0.008030254654447103
+ PX 0.5481471496073655 0 0.18913455936325857
+ ECL IERS2010
+ POSEPOCH 56246.0000000000000000
+ F0 316.12398420312142658 0 4.146144381519702124e-13
+ F1 -2.8153754102544694156e-16 0 1.183507348296795048e-21
+ PEPOCH 56246.0000000000000000
+ CORRECT_TROPOSPHERE N
+ PLANET_SHAPIRO N
+ NE_SW 0.0
+ SWM 0.0
+ DM 18.46722842706362534
+ BINARY DD
+ PB 175.4606618995254901 1 3.282945138073456217e-09
+ PBDOT 0.0
+ A1 55.32972013704859 1 1.4196834101027513e-06
+ A1DOT 1.2273117569157312e-14 1 4.942065359144467e-16
+ ECC 0.0007972619266478279 1 7.475955296127427e-09
+ EDOT -5.4186414842257534e-17 1 1.7814213642177264e-17
+ T0 56188.1561739592320315 1 0.00021082819203119008413
+ OM 50.731718036317336052 1 0.00043256337136868348425
+ OMDOT 0.0
+ M2 0.45422292168060424 1 0.17365015590922117
+ SINI 0.8774306786074643 1 0.048407859962533314
+ A0 0.0
+ B0 0.0
+ GAMMA 0.0
+ DR 0.0
+ DTH 0.0
+ """
+
+ self.m = get_model(io.StringIO(par))
+ self.mDDS = binaryconvert.convert_binary(self.m, "DDS")
+ # use a specific seed for reproducible results
+ np.random.seed(12345)
+ self.t = pint.simulation.make_fake_toas_uniform(
+ 55000, 57000, 100, self.m, error=0.1 * u.us, add_noise=True
+ )
+
+ def test_resids(self):
+ f = pint.fitter.Fitter.auto(self.t, self.m)
+ fDDS = pint.fitter.Fitter.auto(self.t, self.mDDS)
+ assert np.allclose(f.resids.time_resids, fDDS.resids.time_resids)
+
+ def test_ddsfit(self):
+ f = pint.fitter.Fitter.auto(self.t, self.m)
+ f.fit_toas()
+ chi2 = f.resids.calc_chi2()
+ fDDS = pint.fitter.Fitter.auto(self.t, self.mDDS)
+
+ fDDS.fit_toas()
+ chi2DDS = fDDS.resids.calc_chi2()
+ assert np.isclose(
+ 1 - np.exp(-fDDS.model.SHAPMAX.value), f.model.SINI.value, rtol=1e-2
+ )
+ print(f"{chi2} {chi2DDS}")
+ assert np.isclose(chi2, chi2DDS, rtol=1e-2)
+
+ def test_ddsfit_newSHAPMAX(self):
+ f = pint.fitter.Fitter.auto(self.t, self.m)
+ f.fit_toas()
+ chi2 = f.resids.calc_chi2()
+ fDDS = pint.fitter.Fitter.auto(self.t, self.mDDS)
+ fDDS.model.SHAPMAX.quantity += 0.5
+ fDDS.fit_toas()
+ chi2DDS = fDDS.resids.calc_chi2()
+ assert np.isclose(
+ 1 - np.exp(-fDDS.model.SHAPMAX.value), f.model.SINI.value, rtol=1e-2
+ )
+ assert np.isclose(chi2, chi2DDS, rtol=1e-2)
diff --git a/tests/test_derivative_utils.py b/tests/test_derivative_utils.py
index 634ffb2aa..ca35028ae 100644
--- a/tests/test_derivative_utils.py
+++ b/tests/test_derivative_utils.py
@@ -53,9 +53,8 @@ def get_derivative_params(model):
if p.startswith("DMX"):
if not p.startswith("DMX_"):
continue
- else:
- if par.index != 2:
- continue
+ elif par.index != 2:
+ continue
if isinstance(par, pa.MJDParameter) or par.units == u.day:
h = 1e-8
diff --git a/tests/test_design_matrix.py b/tests/test_design_matrix.py
index 63424a042..443888015 100644
--- a/tests/test_design_matrix.py
+++ b/tests/test_design_matrix.py
@@ -14,7 +14,7 @@
class TestDesignMatrix:
- def setup(self):
+ def setup_method(self):
os.chdir(datadir)
self.par_file = "J1614-2230_NANOGrav_12yv3.wb.gls.par"
self.tim_file = "J1614-2230_NANOGrav_12yv3.wb.tim"
@@ -110,3 +110,8 @@ def test_combine_designmatrix_all(self):
]
== 0.0
)
+
+ def test_param_order(self):
+ params_dm = self.model.designmatrix(self.toas, incoffset=False)[1]
+ params_free = self.model.free_params
+ assert params_dm == params_free
diff --git a/tests/test_determinism.py b/tests/test_determinism.py
index bb72729e1..a7ea7820b 100644
--- a/tests/test_determinism.py
+++ b/tests/test_determinism.py
@@ -17,7 +17,7 @@
def test_sampler():
r = []
- for i in range(2):
+ for _ in range(2):
random.seed(0)
numpy.random.seed(0)
s = numpy.random.mtrand.RandomState(0)
@@ -37,10 +37,13 @@ def test_sampler():
phs = 0.0
model = pint.models.get_model(parfile)
- tl = fermi.load_Fermi_TOAs(
- eventfile, weightcolumn=weightcol, minweight=minWeight
+ ts = fermi.get_Fermi_TOAs(
+ eventfile,
+ weightcolumn=weightcol,
+ minweight=minWeight,
+ ephem="DE421",
+ planets=False,
)
- ts = toa.get_TOAs_list(tl, ephem="DE421", planets=False)
# Introduce a small error so that residuals can be calculated
ts.table["error"] = 1.0
ts.filename = eventfile
@@ -80,9 +83,11 @@ def test_sampler():
# fitter.phaseogram()
# samples = sampler.sampler.chain[:, 10:, :].reshape((-1, fitter.n_fit_params))
+ samples = np.transpose(sampler.get_chain(), (1, 0, 2))
# r.append(np.random.randn())
- r.append(sampler.sampler.chain[0])
+ # r.append(sampler.sampler.chain[0])
+ r.append(samples[0])
assert_array_equal(r[0], r[1])
@@ -106,6 +111,7 @@ def log_prob(x, ivar):
sampler.random_state = s
sampler.run_mcmc(p0, 100)
- samples = sampler.chain.reshape((-1, ndim))
+ # samples = sampler.chain.reshape((-1, ndim))
+ samples = np.transpose(sampler.get_chain(), (1, 0, 2)).reshape((-1, ndim))
r.append(samples[0, 0])
assert r[0] == r[1]
diff --git a/tests/test_dmefac_dmequad.py b/tests/test_dmefac_dmequad.py
index c47bf10c1..3987b0cb8 100644
--- a/tests/test_dmefac_dmequad.py
+++ b/tests/test_dmefac_dmequad.py
@@ -1,4 +1,4 @@
-"""Test for the DM uncertaity rescaling DMEFAC and DMEQUAD
+"""Test for the DM uncertainty rescaling DMEFAC and DMEQUAD
"""
from io import StringIO
diff --git a/tests/test_dmx.py b/tests/test_dmx.py
index 84451a032..5f9c3cc4b 100644
--- a/tests/test_dmx.py
+++ b/tests/test_dmx.py
@@ -1,6 +1,6 @@
import logging
import os
-import unittest
+import pytest
import io
import pytest
@@ -36,22 +36,22 @@
"""
-class TestDMX(unittest.TestCase):
+class TestDMX:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parf = os.path.join(datadir, "B1855+09_NANOGrav_dfg+12_DMX.par")
cls.timf = os.path.join(datadir, "B1855+09_NANOGrav_dfg+12.tim")
cls.DMXm = mb.get_model(cls.parf)
cls.toas = toa.get_TOAs(cls.timf, ephem="DE405", include_bipm=False)
- def test_DMX(self):
+ def test_dmx(self):
print("Testing DMX module.")
rs = (
residuals.Residuals(self.toas, self.DMXm, use_weighted_mean=False)
.time_resids.to(u.s)
.value
)
- ltres, _ = np.genfromtxt(self.parf + ".tempo_test", unpack=True)
+ ltres, _ = np.genfromtxt(f"{self.parf}.tempo_test", unpack=True)
resDiff = rs - ltres
assert np.all(
np.abs(resDiff) < 2e-8
@@ -60,11 +60,11 @@ def test_DMX(self):
def test_derivative(self):
log = logging.getLogger("DMX.derivative_test")
p = "DMX_0002"
- log.debug("Runing derivative for %s", "d_delay_d_" + p)
+ log.debug("Running derivative for %s", f"d_delay_d_{p}")
ndf = self.DMXm.d_delay_d_param_num(self.toas, p)
adf = self.DMXm.d_delay_d_param(self.toas, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -72,13 +72,12 @@ def test_derivative(self):
"Derivative test failed at d_delay_d_%s with max relative difference %lf"
% (p, np.nanmax(relative_diff).value)
)
- if p in ["SINI"]:
- tol = 0.7
- else:
- tol = 1e-3
+ tol = 0.7 if p in {"SINI"} else 1e-3
log.debug(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
@@ -160,7 +159,3 @@ def test_multiple_dmxs_explicit_indices_duplicate():
model = get_model(io.StringIO(par))
with pytest.raises(ValueError):
indices = model.add_DMX_ranges([54500, 55500], [55000, 56000], indices=[1, 3])
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_dmxrange_add_sub.py b/tests/test_dmxrange_add_sub.py
index 6949e59e7..ce0dd130f 100644
--- a/tests/test_dmxrange_add_sub.py
+++ b/tests/test_dmxrange_add_sub.py
@@ -53,13 +53,13 @@ def test_unusual_index():
init_len = len(dm_mod.params)
dm_mod.add_DMX_range(mjd_start, mjd_end, index, dmx, frozen=False)
assert len(dm_mod.params) == init_len + 3
- nm = "DMX_" + f"{int(index):04d}"
+ nm = f"DMX_{int(index):04d}"
comp = getattr(dm_mod, nm)
assert comp.value == dmx
- nm = "DMXR1_" + f"{int(index):04d}"
+ nm = f"DMXR1_{int(index):04d}"
comp = getattr(dm_mod, nm)
assert comp.value == mjd_start
- nm = "DMXR2_" + f"{int(index):04d}"
+ nm = f"DMXR2_{int(index):04d}"
comp = getattr(dm_mod, nm)
assert comp.value == mjd_end
@@ -114,13 +114,13 @@ def test_add_DMX():
init_len = len(dm_mod.params)
dm_mod.add_DMX_range(mjd_start, mjd_end, index, dmx, frozen=False)
assert len(dm_mod.params) == init_len + 3
- nm = "DMX_" + f"{int(index):04d}"
+ nm = f"DMX_{index:04d}"
comp = getattr(dm_mod, nm)
assert comp.value == dmx
- nm = "DMXR1_" + f"{int(index):04d}"
+ nm = f"DMXR1_{index:04d}"
comp = getattr(dm_mod, nm)
assert comp.value == mjd_start
- nm = "DMXR2_" + f"{int(index):04d}"
+ nm = f"DMXR2_{index:04d}"
comp = getattr(dm_mod, nm)
assert comp.value == mjd_end
@@ -135,7 +135,7 @@ def test_remove_DMX():
dm_mod.add_DMX_range(mjd_start, mjd_end, index, dmx, frozen=False)
dm_mod.remove_DMX_range(index)
for pn in ["DMX_", "DMXR1_", "DMXR2_"]:
- nm = str(pn) + str(f"{int(index):04d}")
+ nm = str(pn) + str(f"{index:04d}")
assert nm not in dm_mod.params
@@ -161,18 +161,18 @@ def test_model_usage():
dm_mod = model.components["DispersionDMX"]
dm_mod.add_DMX_range(mjd_start, mjd_end, index, dmx, frozen=False)
- nm = "DMX_" + f"{int(index):04d}"
+ nm = f"DMX_{index:04d}"
comp = getattr(dm_mod, nm)
assert comp.value == dmx
- nm = "DMXR1_" + f"{int(index):04d}"
+ nm = f"DMXR1_{index:04d}"
comp = getattr(dm_mod, nm)
assert comp.value == mjd_start
- nm = "DMXR2_" + f"{int(index):04d}"
+ nm = f"DMXR2_{index:04d}"
comp = getattr(dm_mod, nm)
assert comp.value == mjd_end
dm_mod.remove_DMX_range(index)
for pn in ["DMX_", "DMXR1_", "DMXR2_"]:
- nm = pn + str(f"{int(index):04d}")
+ nm = pn + str(f"{index:04d}")
assert nm not in dm_mod.params
init_params = np.array(model.params)
with pytest.raises(ValueError):
diff --git a/tests/test_early_chime_data.py b/tests/test_early_chime_data.py
index f037ed2da..96d4121fb 100644
--- a/tests/test_early_chime_data.py
+++ b/tests/test_early_chime_data.py
@@ -1,6 +1,6 @@
"""Various tests to assess the performance of early CHIME data."""
import os
-import unittest
+import pytest
import pint.models.model_builder as mb
import pint.toa as toa
@@ -8,14 +8,14 @@
from pinttestdata import datadir
-class Test_CHIME_data(unittest.TestCase):
+class Test_CHIME_data:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(self):
+ def setup_class(cls):
os.chdir(datadir)
- self.parfile = "B1937+21.basic.par"
- self.tim = "B1937+21.CHIME.CHIME.NG.N.tim"
+ cls.parfile = "B1937+21.basic.par"
+ cls.tim = "B1937+21.CHIME.CHIME.NG.N.tim"
def test_toa_read(self):
toas = toa.get_TOAs(self.tim, ephem="DE436", planets=False, include_bipm=True)
diff --git a/tests/test_ecorr_average.py b/tests/test_ecorr_average.py
index 184b17510..8c671f601 100644
--- a/tests/test_ecorr_average.py
+++ b/tests/test_ecorr_average.py
@@ -1,4 +1,3 @@
-#! /usr/bin/env python
import os
import astropy.units as u
@@ -25,28 +24,21 @@ def _gen_data(par, tim):
err = m.scaled_sigma(t).to(u.us).value
info = t.get_flag_value("f")
- fout = open(par + ".resids", "w")
- iout = open(par + ".info", "w")
- for i in range(t.ntoas):
- line = "%.10f %.4f %+.8e %.3e 0.0 %s" % (
- mjds[i],
- freqs[i],
- res[i],
- err[i],
- info[i],
- )
- fout.write(line + "\n")
- iout.write(info[i] + "\n")
- fout.close()
- iout.close()
+ with open(f"{par}.resids", "w") as fout:
+ with open(f"{par}.info", "w") as iout:
+ for i in range(t.ntoas):
+ line = "%.10f %.4f %+.8e %.3e 0.0 %s" % (
+ mjds[i],
+ freqs[i],
+ res[i],
+ err[i],
+ info[i],
+ )
+ fout.write(line + "\n")
+ iout.write(info[i] + "\n")
# Requires res_avg in path
- cmd = "cat %s.resids | res_avg -r -t0.0001 -E%s -i%s.info > %s.resavg" % (
- par,
- par,
- par,
- par,
- )
+ cmd = f"cat {par}.resids | res_avg -r -t0.0001 -E{par} -i{par}.info > {par}.resavg"
print(cmd)
# os.system(cmd)
@@ -61,7 +53,7 @@ def test_ecorr_average():
f = GLSFitter(t, m)
# Get comparison resids and uncertainties
mjd, freq, res, err, ophase, chi2, info = np.genfromtxt(
- par + ".resavg", unpack=True
+ f"{par}.resavg", unpack=True
)
resavg_mjd = mjd * u.d
# resavg_freq = freq * u.MHz
diff --git a/tests/test_ell1h.py b/tests/test_ell1h.py
index 374b47280..6167e5eac 100644
--- a/tests/test_ell1h.py
+++ b/tests/test_ell1h.py
@@ -68,7 +68,7 @@ def modelJ0613_STIG():
@pytest.fixture()
def tempo2_res():
parfileJ1853 = "J1853+1303_NANOGrav_11yv0.gls.par"
- return np.genfromtxt(parfileJ1853 + ".tempo2_test", skip_header=1, unpack=True)
+ return np.genfromtxt(f"{parfileJ1853}.tempo2_test", skip_header=1, unpack=True)
def test_J1853(toasJ1853, modelJ1853, tempo2_res):
@@ -105,11 +105,11 @@ def test_derivative(toasJ1853, modelJ1853):
testp["H4"] = 1e-2
testp["STIGMA"] = 1e-2
for p in test_params:
- log.debug("Runing derivative for %s", "d_delay_d_" + p)
+ log.debug("Runing derivative for %s", f"d_delay_d_{p}")
ndf = modelJ1853.d_phase_d_param_num(toasJ1853, p, testp[p])
adf = modelJ1853.d_phase_d_param(toasJ1853, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -117,13 +117,12 @@ def test_derivative(toasJ1853, modelJ1853):
"Derivative test failed at d_delay_d_%s with max relative difference %lf"
% (p, np.nanmax(relative_diff).value)
)
- if p in ["EPS1DOT", "EPS1"]:
- tol = 0.05
- else:
- tol = 1e-3
+ tol = 0.05 if p in ["EPS1DOT", "EPS1"] else 1e-3
log.debug(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
diff --git a/tests/test_erfautils.py b/tests/test_erfautils.py
index d836803e0..c836ff85c 100644
--- a/tests/test_erfautils.py
+++ b/tests/test_erfautils.py
@@ -33,7 +33,7 @@ def test_compare_erfautils_astropy():
dvel = np.sqrt((dopv.vel**2).sum(axis=0))
assert len(dpos) == len(mjds)
# This is just above the level of observed difference
- assert dpos.max() < 0.05 * u.m, "position difference of %s" % dpos.max().to(u.m)
+ assert dpos.max() < 0.05 * u.m, f"position difference of {dpos.max().to(u.m)}"
# This level is what is permitted as a velocity difference from tempo2 in test_times.py
assert dvel.max() < 0.02 * u.mm / u.s, "velocity difference of %s" % dvel.max().to(
u.mm / u.s
@@ -105,13 +105,11 @@ def test_IERS_B_agree_with_IERS_Auto_dX():
assert_equal(A["MJD"][ok_A], B["MJD"][i_B], "MJDs don't make sense")
for tag in ["dX_2000A", "dY_2000A"]:
assert_allclose(
- A[tag + "_B"][ok_A].to(u.marcsec).value,
+ A[f"{tag}_B"][ok_A].to(u.marcsec).value,
B[tag][i_B].to(u.marcsec).value,
atol=1e-5,
rtol=1e-3,
- err_msg="IERS A-derived IERS B {} values don't match current IERS B values".format(
- tag
- ),
+ err_msg=f"IERS A-derived IERS B {tag} values don't match current IERS B values",
)
@@ -143,10 +141,8 @@ def test_IERS_B_agree_with_IERS_Auto():
A[atag][ok_A].to(unit).value,
B[btag][i_B].to(unit).value,
atol=1e-5,
- rtol=1e-5, # should be "close enough"
- err_msg="Inserted IERS B {} values don't match IERS_B_URL {} values".format(
- atag, btag
- ),
+ rtol=1e-5,
+ err_msg=f"Inserted IERS B {atag} values don't match IERS_B_URL {btag} values",
)
@@ -213,10 +209,8 @@ def test_IERS_B_builtin_agree_with_IERS_Auto():
A[atag][ok_A].to(unit).value,
B[btag][i_B].to(unit).value,
atol=1e-5,
- rtol=1e-5, # should be exactly equal
- err_msg="Inserted IERS B {} values don't match IERS_B_FILE {} values".format(
- atag, btag
- ),
+ rtol=1e-5,
+ err_msg=f"Inserted IERS B {atag} values don't match IERS_B_FILE {btag} values",
)
@@ -250,7 +244,5 @@ def test_IERS_B_parameters_loaded_into_IERS_Auto(b_name, a_name):
assert_equal(
A[a_name][ok_A],
B[b_name][i_B],
- err_msg="IERS B parameter {} not copied over IERS A parameter {}".format(
- b_name, a_name
- ),
+ err_msg=f"IERS B parameter {b_name} not copied over IERS A parameter {a_name}",
)
diff --git a/tests/test_event_optimize.py b/tests/test_event_optimize.py
index d900835d2..9fff6688f 100644
--- a/tests/test_event_optimize.py
+++ b/tests/test_event_optimize.py
@@ -1,6 +1,5 @@
-# #!/usr/bin/env python
# This test is DISABLED because event_optimize requires PRESTO to be installed
-# to get the fftfit module. It can be run manually by people who have PRESTO
+# to get the fftfit module. It can be run manually by people who have PRESTO.
# Actually it's not disabled? Unclear what the above is supposed to mean.
import os
import shutil
@@ -80,7 +79,7 @@ def test_parallel(tmp_path):
for i in range(samples1.shape[1]):
assert stats.ks_2samp(samples1[:, i], samples2[:, i])[1] == 1.0
except ImportError:
- pytest.skip
+ pytest.skip(f"Pathos multiprocessing package not found")
finally:
os.chdir(p)
sys.stdout = saved_stdout
@@ -105,11 +104,11 @@ def test_backend(tmp_path):
samples = None
# Running with backend
+ os.chdir(tmp_path)
cmd = f"{eventfile} {parfile} {temfile} --weightcol=PSRJ0030+0451 --minWeight=0.9 --nwalkers=10 --nsteps=50 --burnin=10 --backend --clobber"
event_optimize.maxpost = -9e99
event_optimize.numcalls = 0
event_optimize.main(cmd.split())
-
reader = emcee.backends.HDFBackend("J0030+0451_chains.h5")
samples = reader.get_chain(discard=10)
assert samples is not None
@@ -125,7 +124,32 @@ def test_backend(tmp_path):
assert timestamp == os.path.getmtime("J0030+0451_chains.h5")
except ImportError:
- pytest.skip
+ pytest.skip(f"h5py package not found")
finally:
os.chdir(p)
sys.stdout = saved_stdout
+
+
+def test_autocorr(tmp_path):
+ # Defining a log posterior function based on the emcee tutorials
+ def ln_prob(theta):
+ ln_prior = -0.5 * np.sum((theta - 1.0) ** 2 / 100.0)
+ ln_prob = -0.5 * np.sum(theta**2) + ln_prior
+ return ln_prob
+
+ # Setting a random starting position for 10 walkers with 5 dimenisions
+ coords = np.random.randn(10, 5)
+ nwalkers, ndim = coords.shape
+
+ # Establishing the Sampler
+ nsteps = 500000
+ sampler = emcee.EnsembleSampler(nwalkers, ndim, ln_prob)
+
+ # Running the sampler with the autocorrelation check function from event_optimize
+ autocorr = event_optimize.run_sampler_autocorr(sampler, coords, nsteps, burnin=10)
+
+ # Extracting the samples and asserting that the autocorrelation check
+ # stopped the sampler once convergence was reached
+ samples = np.transpose(sampler.get_chain(discard=10), (1, 0, 2)).reshape((-1, ndim))
+
+ assert len(samples) < nsteps
diff --git a/tests/test_event_optimize_MCMCFitter.py b/tests/test_event_optimize_MCMCFitter.py
index caf579f1c..c07f863ab 100644
--- a/tests/test_event_optimize_MCMCFitter.py
+++ b/tests/test_event_optimize_MCMCFitter.py
@@ -1,8 +1,7 @@
-#!/usr/bin/env python
# This test is DISABLED because event_optimize requires PRESTO to be installed
# to get the fftfit module. It can be run manually by people who have PRESTO
import os
-import unittest
+import pytest
from io import StringIO
@@ -17,8 +16,8 @@
# SMR skipped this test as event_optimize_MCMCFitter isn't used anywhere
# How/why is it different from event_optimize?
-@unittest.skip
-class TestEventOptimizeMCMCFitter(unittest.TestCase):
+@pytest.mark.skip
+class TestEventOptimizeMCMCFitter:
def test_result(self):
import pint.scripts.event_optimize_MCMCFitter as event_optimize
@@ -34,7 +33,3 @@ def test_result(self):
lines = event_optimize.sys.stdout.getvalue()
# Need to add some check here.
event_optimize.sys.stdout = saved_stdout
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_event_optimize_multiple.py b/tests/test_event_optimize_multiple.py
index e67640f82..33c042d06 100644
--- a/tests/test_event_optimize_multiple.py
+++ b/tests/test_event_optimize_multiple.py
@@ -1,9 +1,8 @@
-#!/usr/bin/env python
# This test is DISABLED because event_optimize requires PRESTO to be installed
# to get the fftfit module. It can be run manually by people who have PRESTO
import os
import sys
-import unittest
+import pytest
from io import StringIO
@@ -13,8 +12,8 @@
eventfile = os.path.join(datadir, "evtfiles.txt")
-@unittest.skip
-class TestEventOptimizeMultiple(unittest.TestCase):
+@pytest.mark.skip
+class TestEventOptimizeMultiple:
def test_result(self):
# Delay import because of fftfit
from pint.scripts import event_optimize_multiple
@@ -27,7 +26,3 @@ def test_result(self):
lines = sys.stdout.getvalue()
# Need to add some check here.
sys.stdout = saved_stdout
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_event_toas.py b/tests/test_event_toas.py
index 434a1798e..8e2720cf5 100644
--- a/tests/test_event_toas.py
+++ b/tests/test_event_toas.py
@@ -1,8 +1,11 @@
import os
import pytest
+import numpy as np
+
+from astropy import units as u
from pint.event_toas import read_mission_info_from_heasoft, create_mission_config
-from pint.event_toas import load_fits_TOAs
+from pint.event_toas import get_fits_TOAs, get_NICER_TOAs, _default_uncertainty
from pinttestdata import datadir
@@ -48,7 +51,7 @@ def test_load_events_wrongext_raises():
with pytest.raises(ValueError) as excinfo:
# Not sure how to test that the warning is raised, with Astropy's log system
# Anyway, here I'm testing another error
- load_fits_TOAs(eventfile_nicer_topo, mission="xdsgse", extension=2)
+ get_fits_TOAs(eventfile_nicer_topo, mission="xdsgse", extension=2)
assert msg in str(excinfo.value)
@@ -58,5 +61,25 @@ def test_load_events_wrongext_text_raises():
with pytest.raises(RuntimeError) as excinfo:
# Not sure how to test that the warning is raised, with Astropy's log system
# Anyway, here I'm testing another error
- load_fits_TOAs(eventfile_nicer_topo, mission="xdsgse", extension="dafasdfa")
+ get_fits_TOAs(eventfile_nicer_topo, mission="xdsgse", extension="dafasdfa")
assert msg in str(excinfo.value)
+
+
+def test_for_toa_errors_default():
+ eventfile_nicer = datadir / "ngc300nicer_bary.evt"
+
+ ts = get_NICER_TOAs(
+ eventfile_nicer,
+ )
+ assert np.all(ts.get_errors() == _default_uncertainty["NICER"])
+
+
+@pytest.mark.parametrize("errors", [2, 2 * u.us])
+def test_for_toa_errors_manual(errors):
+ eventfile_nicer = datadir / "ngc300nicer_bary.evt"
+
+ ts = get_NICER_TOAs(
+ eventfile_nicer,
+ errors=errors,
+ )
+ assert np.all(ts.get_errors() == 2 * u.us)
diff --git a/tests/test_eventstats.py b/tests/test_eventstats.py
index d3ba0ff3c..61f2005ed 100755
--- a/tests/test_eventstats.py
+++ b/tests/test_eventstats.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
from numpy.testing import assert_allclose
import pint.eventstats as es
diff --git a/tests/test_explicit_absphase.py b/tests/test_explicit_absphase.py
new file mode 100644
index 000000000..86dd3f18e
--- /dev/null
+++ b/tests/test_explicit_absphase.py
@@ -0,0 +1,67 @@
+from pint.models import get_model, get_model_and_toas
+from pint.simulation import make_fake_toas_uniform
+from pint.residuals import Residuals
+from io import StringIO
+import pytest
+
+par = """
+F0 100
+PEPOCH 50000
+"""
+
+
+@pytest.fixture
+def fake_toas():
+ m = get_model(StringIO(par))
+ t = make_fake_toas_uniform(50000, 50100, 50, m, add_noise=True)
+ return t
+
+
+@pytest.fixture
+def model():
+ return get_model(StringIO(par))
+
+
+def test_add_tzr_toa(model, fake_toas):
+ assert "AbsPhase" not in model.components
+
+ model.add_tzr_toa(fake_toas)
+
+ assert "AbsPhase" in model.components
+ assert hasattr(model, "TZRMJD")
+ assert hasattr(model, "TZRSITE")
+ assert hasattr(model, "TZRFRQ")
+
+ with pytest.raises(ValueError):
+ model.add_tzr_toa(fake_toas)
+
+
+@pytest.mark.parametrize("use_abs_phase", [True, False])
+def test_residuals(model, fake_toas, use_abs_phase):
+ res = Residuals(fake_toas, model, use_abs_phase=use_abs_phase)
+ res.calc_phase_resids()
+ assert ("AbsPhase" in model.components) == use_abs_phase
+
+
+@pytest.mark.parametrize("use_abs_phase", [True, False])
+def test_residuals_2(model, fake_toas, use_abs_phase):
+ res = Residuals(fake_toas, model, use_abs_phase=False)
+ res.calc_phase_resids(use_abs_phase=use_abs_phase)
+ assert ("AbsPhase" in model.components) == use_abs_phase
+
+
+@pytest.mark.parametrize("add_tzr", [True, False])
+def test_get_model(fake_toas, add_tzr):
+ toas_for_tzr = fake_toas if add_tzr else None
+ m = get_model(StringIO(par), toas_for_tzr=toas_for_tzr)
+ assert ("AbsPhase" in m.components) == add_tzr
+
+
+@pytest.mark.parametrize("add_tzr", [True, False])
+def test_get_model_and_toas(fake_toas, add_tzr):
+ timfile = "fake_toas.tim"
+ fake_toas.write_TOA_file(timfile)
+
+ m, t = get_model_and_toas(StringIO(par), timfile, add_tzr_to_model=add_tzr)
+
+ assert ("AbsPhase" in m.components) == add_tzr
diff --git a/tests/test_fake_toas.py b/tests/test_fake_toas.py
index 419659ad9..dfa3d95f4 100644
--- a/tests/test_fake_toas.py
+++ b/tests/test_fake_toas.py
@@ -259,7 +259,7 @@ def test_fake_uniform(t1, t2):
assert np.isclose(r.calc_time_resids().std(), 1 * u.us, rtol=0.2)
-def test_fake_from_timfile():
+def test_fake_from_toas():
# FIXME: this file is unnecessarily huge
m, t = get_model_and_toas(
pint.config.examplefile("B1855+09_NANOGrav_9yv1.gls.par"),
@@ -279,6 +279,29 @@ def test_fake_from_timfile():
)
+@pytest.mark.parametrize("planets", (True, False))
+def test_fake_from_timfile(planets):
+ m = get_model(pint.config.examplefile("NGC6440E.par.good"))
+ t = get_TOAs(pint.config.examplefile("NGC6440E.tim"), planets=planets)
+
+ m.PLANET_SHAPIRO.value = planets
+
+ r = pint.residuals.Residuals(t, m)
+
+ t_sim = pint.simulation.make_fake_toas_fromtim(
+ pint.config.examplefile("NGC6440E.tim"), m, add_noise=True
+ )
+ r_sim = pint.residuals.Residuals(t_sim, m)
+
+ m, t = get_model_and_toas(
+ pint.config.examplefile("B1855+09_NANOGrav_9yv1.gls.par"),
+ pint.config.examplefile("B1855+09_NANOGrav_9yv1.tim"),
+ )
+ assert np.isclose(
+ r.calc_time_resids().std(), r_sim.calc_time_resids().std(), rtol=2
+ )
+
+
def test_fake_highF1():
m = get_model(os.path.join(datadir, "ngc300nicer.par"))
m.F1.quantity *= 10
diff --git a/tests/test_fbx.py b/tests/test_fbx.py
index 3c07dbc61..4cf278b07 100644
--- a/tests/test_fbx.py
+++ b/tests/test_fbx.py
@@ -29,7 +29,7 @@ def modelJ0023():
ltres, ltbindelay = np.genfromtxt(
- parfileJ0023 + ".tempo2_test", skip_header=1, unpack=True
+ f"{parfileJ0023}.tempo2_test", skip_header=1, unpack=True
)
@@ -51,13 +51,13 @@ def test_derivative(modelJ0023, toasJ0023):
testp = tdu.get_derivative_params(modelJ0023)
delay = modelJ0023.delay(toasJ0023)
for p in testp.keys():
- print("Runing derivative for %s", "d_delay_d_" + p)
+ print("Runing derivative for %s", f"d_delay_d_{p}")
if p in ["EPS2", "EPS1"]:
testp[p] = 15
ndf = modelJ0023.d_phase_d_param_num(toasJ0023, p, testp[p])
adf = modelJ0023.d_phase_d_param(toasJ0023, delay, p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -76,8 +76,10 @@ def test_derivative(modelJ0023, toasJ0023):
else:
tol = 1e-3
print(
- "derivative relative diff for %s, %lf"
- % ("d_delay_d_" + p, np.nanmax(relative_diff).value)
+ (
+ "derivative relative diff for %s, %lf"
+ % (f"d_delay_d_{p}", np.nanmax(relative_diff).value)
+ )
)
assert np.nanmax(relative_diff) < tol, msg
else:
diff --git a/tests/test_fd.py b/tests/test_fd.py
index db3aa3053..75d1d7e4c 100644
--- a/tests/test_fd.py
+++ b/tests/test_fd.py
@@ -1,7 +1,7 @@
"""Various tests to assess the performance of the FD model."""
import copy
import os
-import unittest
+import pytest
from io import StringIO
import astropy.units as u
@@ -14,17 +14,17 @@
import pint.toa as toa
-class TestFD(unittest.TestCase):
+class TestFD:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.parf = os.path.join(datadir, "test_FD.par")
cls.timf = os.path.join(datadir, "test_FD.simulate.pint_corrected")
cls.FDm = mb.get_model(cls.parf)
cls.toas = toa.get_TOAs(cls.timf, include_bipm=False)
# libstempo result
- cls.ltres, cls.ltbindelay = np.genfromtxt(cls.parf + ".tempo_test", unpack=True)
+ cls.ltres, cls.ltbindelay = np.genfromtxt(f"{cls.parf}.tempo_test", unpack=True)
- def test_FD(self):
+ def test_fd(self):
print("Testing FD module.")
rs = (
pint.residuals.Residuals(self.toas, self.FDm, use_weighted_mean=False)
@@ -32,7 +32,7 @@ def test_FD(self):
.value
)
resDiff = rs - self.ltres
- # NOTE : This prescision is a lower then 1e-7 seconds level, due to some
+ # NOTE : This precision is a lower then 1e-7 seconds level, due to some
# early parks clock corrections are treated differently.
# TEMPO2: Clock correction = clock0 + clock1 (in the format of general2)
# PINT : Clock correction = toas.table['flags']['clkcorr']
@@ -41,13 +41,13 @@ def test_FD(self):
def test_inf_freq(self):
test_toas = copy.deepcopy(self.toas)
- test_toas.table["freq"][0:5] = np.inf * u.MHz
+ test_toas.table["freq"][:5] = np.inf * u.MHz
fd_delay = self.FDm.components["FD"].FD_delay(test_toas)
assert np.all(
np.isfinite(fd_delay)
), "FD component is not handling infinite frequency right."
assert np.all(
- fd_delay[0:5].value == 0.0
+ fd_delay[:5].value == 0.0
), "FD component did not compute infinite frequency delay right"
d_d_d_fd = self.FDm.d_delay_FD_d_FDX(test_toas, "FD1")
assert np.all(np.isfinite(d_d_d_fd)), (
diff --git a/tests/test_fermiphase.py b/tests/test_fermiphase.py
index e395bd389..1be357cb5 100644
--- a/tests/test_fermiphase.py
+++ b/tests/test_fermiphase.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
import os
import sys
from io import StringIO
@@ -6,12 +5,15 @@
import pytest
+import numpy as np
+
from astropy.io import fits
+from astropy import units as u
import pint.models
import pint.scripts.fermiphase as fermiphase
import pint.toa as toa
-from pint.fermi_toas import load_Fermi_TOAs
+from pint.fermi_toas import get_Fermi_TOAs, _default_uncertainty
from pint.observatory.satellite_obs import get_satellite_observatory
from pinttestdata import datadir
@@ -58,10 +60,13 @@ def test_process_and_accuracy():
modelin = pint.models.get_model(parfile)
get_satellite_observatory("Fermi", ft2file)
- tl = load_Fermi_TOAs(eventfileraw, weightcolumn="PSRJ0030+0451")
- # ts = toa.TOAs(toalist=tl)
- ts = toa.get_TOAs_list(
- tl, include_gps=False, include_bipm=False, planets=False, ephem="DE405"
+ ts = get_Fermi_TOAs(
+ eventfileraw,
+ weightcolumn="PSRJ0030+0451",
+ include_gps=False,
+ include_bipm=False,
+ planets=False,
+ ephem="DE405",
)
iphss, phss = modelin.phase(ts, abs_phase=True)
ph_pint = phss % 1
@@ -80,3 +85,31 @@ def test_process_and_accuracy():
# require absolute phase to be within 500 ns; NB this relies on
# GBT clock corrections since the TZR is referenced there
assert max(abs(resids_mus)) < 0.5
+
+
+def test_for_toa_errors_default():
+ get_satellite_observatory("Fermi", ft2file, overwrite=True)
+ ts = get_Fermi_TOAs(
+ eventfileraw,
+ weightcolumn="PSRJ0030+0451",
+ include_gps=False,
+ include_bipm=False,
+ planets=False,
+ ephem="DE405",
+ )
+ assert np.all(ts.get_errors() == _default_uncertainty)
+
+
+@pytest.mark.parametrize("errors", [2, 2 * u.us])
+def test_for_toa_errors_manual(errors):
+ get_satellite_observatory("Fermi", ft2file, overwrite=True)
+ ts = get_Fermi_TOAs(
+ eventfileraw,
+ weightcolumn="PSRJ0030+0451",
+ include_gps=False,
+ include_bipm=False,
+ planets=False,
+ ephem="DE405",
+ errors=errors,
+ )
+ assert np.all(ts.get_errors() == 2 * u.us)
diff --git a/tests/test_find_clock_file.py b/tests/test_find_clock_file.py
index 7a7355977..ee5d62b5e 100644
--- a/tests/test_find_clock_file.py
+++ b/tests/test_find_clock_file.py
@@ -1,6 +1,7 @@
import os
import astropy.units as u
+import contextlib
import numpy as np
import pytest
from astropy.time import Time
@@ -17,18 +18,12 @@ class Sandbox:
o = Sandbox()
e = os.environ.copy()
- try:
+ with contextlib.suppress(KeyError):
del os.environ["PINT_CLOCK_OVERRIDE"]
- except KeyError:
- pass
- try:
+ with contextlib.suppress(KeyError):
del os.environ["TEMPO"]
- except KeyError:
- pass
- try:
+ with contextlib.suppress(KeyError):
del os.environ["TEMPO2"]
- except KeyError:
- pass
o.override_dir = tmp_path / "override"
o.override_dir.mkdir()
o.repo_dir = tmp_path / "repo"
@@ -56,7 +51,7 @@ class Sandbox:
T2runtime/clock/fake.clk 7.0 ---
"""
)
- o.repo_uri = o.repo_dir.as_uri() + "/"
+ o.repo_uri = f"{o.repo_dir.as_uri()}/"
o.t2_dir = tmp_path / "t2"
o.t2_clock = o.clocks[2]
diff --git a/tests/test_fitter.py b/tests/test_fitter.py
index 1b60b808d..dd1409dc2 100644
--- a/tests/test_fitter.py
+++ b/tests/test_fitter.py
@@ -1,4 +1,3 @@
-#! /usr/bin/env python
import os
from io import StringIO
from copy import deepcopy
@@ -144,10 +143,10 @@ def test_ftest_wb():
A1DOT_Component = "BinaryELL1"
# Test adding A1DOT
Ftest_dict = wb_f.ftest(A1DOT, A1DOT_Component, remove=False, full_output=True)
- assert isinstance(Ftest_dict["ft"], float) or isinstance(Ftest_dict["ft"], bool)
+ assert isinstance(Ftest_dict["ft"], (float, bool))
# Test removing parallax
Ftest_dict = wb_f.ftest(PX, PX_Component, remove=True, full_output=True)
- assert isinstance(Ftest_dict["ft"], float) or isinstance(Ftest_dict["ft"], bool)
+ assert isinstance(Ftest_dict["ft"], (float, bool))
def test_fitsummary_binary():
diff --git a/tests/test_fitter_compare.py b/tests/test_fitter_compare.py
index 1d16c3175..3da6df74d 100644
--- a/tests/test_fitter_compare.py
+++ b/tests/test_fitter_compare.py
@@ -1,4 +1,4 @@
-#! /usr/bin/env python
+import contextlib
from os.path import join
from io import StringIO
import copy
@@ -8,7 +8,6 @@
import pytest
from pinttestdata import datadir
-import pint
from pint.fitter import (
MaxiterReached,
DownhillGLSFitter,
@@ -57,22 +56,16 @@ def test_compare_gls(full_cov, wls):
def test_compare_downhill_wls(wls):
dwls = DownhillWLSFitter(wls.toas, wls.model_init)
- try:
+ with contextlib.suppress(MaxiterReached):
dwls.fit_toas(maxiter=1)
- except MaxiterReached:
- pass
-
assert abs(wls.resids.chi2 - dwls.resids.chi2) < 0.01
@pytest.mark.parametrize("full_cov", [False, True])
def test_compare_downhill_gls(full_cov, wls):
gls = DownhillGLSFitter(wls.toas, wls.model_init)
- try:
+ with contextlib.suppress(MaxiterReached):
gls.fit_toas(maxiter=1, full_cov=full_cov)
- except MaxiterReached:
- pass
-
# Why is this taking a different step from the plain GLS fitter?
assert abs(wls.resids_init.chi2 - gls.resids_init.chi2) < 0.01
assert abs(wls.resids.chi2 - gls.resids.chi2) < 0.01
@@ -81,11 +74,8 @@ def test_compare_downhill_gls(full_cov, wls):
@pytest.mark.parametrize("full_cov", [False, True])
def test_compare_downhill_wb(full_cov, wb):
dwb = WidebandDownhillFitter(wb.toas, wb.model_init)
- try:
+ with contextlib.suppress(MaxiterReached):
dwb.fit_toas(maxiter=1, full_cov=full_cov)
- except MaxiterReached:
- pass
-
assert abs(wb.resids.chi2 - dwb.resids.chi2) < 0.01
@@ -126,18 +116,14 @@ def m_t():
def test_step_different_with_efacs(fitter, m_t):
m, t = m_t
f = fitter(t, m)
- try:
+ with contextlib.suppress(MaxiterReached):
f.fit_toas(maxiter=1)
- except MaxiterReached:
- pass
m2 = copy.deepcopy(m)
m2.EFAC1.value = 1
m2.EFAC2.value = 1
f2 = fitter(t, m2)
- try:
+ with contextlib.suppress(MaxiterReached):
f2.fit_toas(maxiter=1)
- except MaxiterReached:
- pass
for p in m.free_params:
assert getattr(f.model, p).value != getattr(f2.model, p).value
@@ -154,18 +140,14 @@ def test_step_different_with_efacs(fitter, m_t):
def test_step_different_with_efacs_full_cov(fitter, m_t):
m, t = m_t
f = fitter(t, m)
- try:
+ with contextlib.suppress(MaxiterReached):
f.fit_toas(maxiter=1, full_cov=True)
- except MaxiterReached:
- pass
m2 = copy.deepcopy(m)
m2.EFAC1.value = 1
m2.EFAC2.value = 1
f2 = fitter(t, m2)
- try:
+ with contextlib.suppress(MaxiterReached):
f2.fit_toas(maxiter=1, full_cov=True)
- except MaxiterReached:
- pass
for p in m.free_params:
assert getattr(f.model, p).value != getattr(f2.model, p).value
@@ -182,14 +164,10 @@ def test_downhill_same_step(fitter1, fitter2, m_t):
m, t = m_t
f1 = fitter1(t, m)
f2 = fitter2(t, m)
- try:
+ with contextlib.suppress(MaxiterReached):
f1.fit_toas(maxiter=1)
- except MaxiterReached:
- pass
- try:
+ with contextlib.suppress(MaxiterReached):
f2.fit_toas(maxiter=1)
- except MaxiterReached:
- pass
for p in m.free_params:
assert np.isclose(
getattr(f1.model, p).value - getattr(f1.model_init, p).value,
diff --git a/tests/test_fitter_error_checking.py b/tests/test_fitter_error_checking.py
index 75a35a864..53b59241b 100644
--- a/tests/test_fitter_error_checking.py
+++ b/tests/test_fitter_error_checking.py
@@ -1,4 +1,3 @@
-#! /usr/bin/env python
import io
import numpy as np
diff --git a/tests/test_flagging_clustering.py b/tests/test_flagging_clustering.py
index 702e9685e..21f0e426c 100644
--- a/tests/test_flagging_clustering.py
+++ b/tests/test_flagging_clustering.py
@@ -78,7 +78,3 @@ def test_jump_by_cluster_invalidflags(setup_NGC6440E):
add_column=False,
add_flag=1,
)
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_glitch.py b/tests/test_glitch.py
index 138aec560..4d7eac8ad 100644
--- a/tests/test_glitch.py
+++ b/tests/test_glitch.py
@@ -1,6 +1,5 @@
-#! /usr/bin/env python
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -16,9 +15,9 @@
timfile = os.path.join(datadir, "prefixtest.tim")
-class TestGlitch(unittest.TestCase):
+class TestGlitch:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.m = pint.models.get_model(parfile)
cls.t = pint.toa.get_TOAs(timfile, ephem="DE405", include_bipm=False)
cls.f = pint.fitter.WLSFitter(cls.t, cls.m)
@@ -29,7 +28,7 @@ def test_glitch(self):
# Now do the fit
print("Fitting...")
self.f.fit_toas()
- emsg = "RMS of " + self.m.PSR.value + " is too big."
+ emsg = f"RMS of {self.m.PSR.value} is too big."
assert self.f.resids.time_resids.std().to(u.us).value < 950.0, emsg
@pytest.mark.filterwarnings("ignore:invalid value")
@@ -47,17 +46,11 @@ def test_glitch_der(self):
adf = self.m.d_phase_d_param(self.t, delay, param)
param_obj = getattr(self.m, param)
# Get numerical derivative steps.
- if param_obj.units == u.day:
- h = 1e-8
- else:
- h = 1e-2
+ h = 1e-8 if param_obj.units == u.day else 1e-2
ndf = self.m.d_phase_d_param_num(self.t, param, h)
diff = adf - ndf
mean = (adf + ndf) / 2.0
r_diff = diff / mean
- errormsg = (
- "Derivatives for %s is not accurate, max relative difference is"
- % param
- )
+ errormsg = f"Derivatives for {param} is not accurate, max relative difference is"
errormsg += " %lf" % np.nanmax(np.abs(r_diff.value))
assert np.nanmax(np.abs(r_diff.value)) < 1e-3, errormsg
diff --git a/tests/test_global_clock_corrections.py b/tests/test_global_clock_corrections.py
index 06d3dcdde..c517f2293 100644
--- a/tests/test_global_clock_corrections.py
+++ b/tests/test_global_clock_corrections.py
@@ -12,7 +12,7 @@
def test_not_existing(tmp_path, temp_cache):
- url_base = tmp_path.as_uri() + "/"
+ url_base = f"{tmp_path.as_uri()}/"
test_file_name = "test_file"
url = url_base + test_file_name
@@ -27,7 +27,7 @@ def test_not_existing(tmp_path, temp_cache):
def test_existing(tmp_path, temp_cache):
- url_base = tmp_path.as_uri() + "/"
+ url_base = f"{tmp_path.as_uri()}/"
test_file_name = "test_file"
url = url_base + test_file_name
@@ -46,7 +46,7 @@ def test_existing(tmp_path, temp_cache):
def test_update_needed(tmp_path, temp_cache):
- url_base = tmp_path.as_uri() + "/"
+ url_base = f"{tmp_path.as_uri()}/"
test_file_name = "test_file"
url = url_base + test_file_name
@@ -94,7 +94,7 @@ class Sandbox:
# File Update (days) Invalid if older than
{sandbox.filename} 7.0 --- """
)
- sandbox.repo_url = sandbox.repo_path.as_uri() + "/"
+ sandbox.repo_url = f"{sandbox.repo_path.as_uri()}/"
return sandbox
diff --git a/tests/test_gls_fitter.py b/tests/test_gls_fitter.py
index fa2234877..9ad629d36 100644
--- a/tests/test_gls_fitter.py
+++ b/tests/test_gls_fitter.py
@@ -1,7 +1,6 @@
-#! /usr/bin/env python
import json
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -12,11 +11,11 @@
from pinttestdata import datadir
-class TestGLS(unittest.TestCase):
+class TestGLS:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.par = "B1855+09_NANOGrav_9yv1.gls.par"
cls.tim = "B1855+09_NANOGrav_9yv1.tim"
@@ -53,15 +52,13 @@ def test_gls_fitter(self):
if par not in ["ELONG", "ELAT"]
else getattr(self.f.model, par).uncertainty.to(u.rad).value
)
- msg = "Parameter {} does not match T2 for full_cov={}".format(
- par, full_cov
- )
+ msg = f"Parameter {par} does not match T2 for full_cov={full_cov}"
assert np.abs(v - val[0]) <= val[1], msg
assert np.abs(v - val[0]) <= e, msg
assert np.abs(1 - val[1] / e) < 0.1, msg
def test_noise_design_matrix_index(self):
- self.fit(False, True) # get the debug infor
+ self.fit(False, True) # get the debug info
# Test red noise basis
pl_rd = self.f.model.pl_rn_basis_weight_pair(self.f.toas)[0]
p0, p1 = self.f.resids.pl_red_noise_M_index
diff --git a/tests/test_jump.py b/tests/test_jump.py
index f91d6f06e..be53697fc 100644
--- a/tests/test_jump.py
+++ b/tests/test_jump.py
@@ -1,7 +1,7 @@
"""Tests for jump model component """
import logging
import os
-import unittest
+import pytest
import pytest
import astropy.units as u
@@ -14,6 +14,7 @@
from pint.models import parameter as p
from pint.models import PhaseJump
import pint.models.timing_model
+import pint.fitter
class SimpleSetup:
@@ -83,12 +84,14 @@ def test_remove_jump_and_flags(setup_NGC6440E):
# test delete_jump_and_flags
setup_NGC6440E.m.delete_jump_and_flags(setup_NGC6440E.t.table["flags"], 1)
assert len(cp.jumps) == 1
+ f = pint.fitter.Fitter.auto(setup_NGC6440E.t, setup_NGC6440E.m)
# delete last jump
setup_NGC6440E.m.delete_jump_and_flags(setup_NGC6440E.t.table["flags"], 2)
for d in setup_NGC6440E.t.table["flags"][selected_toa_ind2]:
assert "jump" not in d
assert "PhaseJump" not in setup_NGC6440E.m.components
+ f = pint.fitter.Fitter.auto(setup_NGC6440E.t, setup_NGC6440E.m)
def test_jump_params_to_flags(setup_NGC6440E):
@@ -190,9 +193,9 @@ def test_find_empty_masks(setup_NGC6440E):
setup_NGC6440E.m.validate_toas(setup_NGC6440E.t)
-class TestJUMP(unittest.TestCase):
+class TestJUMP:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parf = "B1855+09_NANOGrav_dfg+12_TAI.par"
cls.timf = "B1855+09_NANOGrav_dfg+12.tim"
@@ -202,7 +205,7 @@ def setUpClass(cls):
)
# libstempo calculation
cls.ltres = np.genfromtxt(
- cls.parf + ".tempo_test", unpack=False, names=True, dtype=np.longdouble
+ f"{cls.parf}.tempo_test", unpack=False, names=True, dtype=np.longdouble
)
def test_jump(self):
@@ -216,11 +219,11 @@ def test_jump(self):
def test_derivative(self):
log = logging.getLogger("Jump phase test")
p = "JUMP2"
- log.debug("Runing derivative for %s", "d_delay_d_" + p)
+ log.debug("Runing derivative for %s", f"d_delay_d_{p}")
ndf = self.JUMPm.d_phase_d_param_num(self.toas, p)
adf = self.JUMPm.d_phase_d_param(self.toas, self.JUMPm.delay(self.toas), p)
diff = adf - ndf
- if not np.all(diff.value) == 0.0:
+ if np.all(diff.value) != 0.0:
mean_der = (adf + ndf) / 2.0
relative_diff = np.abs(diff) / np.abs(mean_der)
# print "Diff Max is :", np.abs(diff).max()
@@ -229,7 +232,3 @@ def test_derivative(self):
% (p, np.nanmax(relative_diff).value)
)
assert np.nanmax(relative_diff) < 0.001, msg
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_leapsec.py b/tests/test_leapsec.py
index 665ce4a44..db42ac7e5 100644
--- a/tests/test_leapsec.py
+++ b/tests/test_leapsec.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
import numpy as np
from astropy.time import Time
diff --git a/tests/test_mask_parameter.py b/tests/test_mask_parameter.py
index 03b81f14e..ffb532c7c 100644
--- a/tests/test_mask_parameter.py
+++ b/tests/test_mask_parameter.py
@@ -23,7 +23,7 @@ def test_mjd_mask(toas):
mp = maskParameter("test1", key="mjd", key_value=[54000, 54100])
assert mp.key == "mjd"
assert mp.key_value == [54000, 54100]
- assert mp.value == None
+ assert mp.value is None
select_toas = mp.select_toa_mask(toas)
assert len(select_toas) > 0
raw_selection = np.where(
diff --git a/tests/test_model.py b/tests/test_model.py
index 0295bcec5..a3b0edf71 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,4 +1,4 @@
-#! /usr/bin/env python
+import contextlib
import io
import os
import time
@@ -91,7 +91,7 @@ def test_model():
# run tempo1 also, if the tempo_utils module is available
did_tempo1 = False
- try:
+ with contextlib.suppress(Exception):
import tempo_utils
log.info("Running TEMPO1...")
@@ -114,9 +114,6 @@ def test_model():
% np.fabs(diff_t2_t1).max().value
)
log.info("Std resid diff between T1 and T2: %.2f ns" % diff_t2_t1.std().value)
- except:
- pass
-
if did_tempo1 and not planets:
assert np.fabs(diff_t1).max() < 32.0 * u.ns
diff --git a/tests/test_model_derivatives.py b/tests/test_model_derivatives.py
index 405ddd13a..cc322c851 100644
--- a/tests/test_model_derivatives.py
+++ b/tests/test_model_derivatives.py
@@ -177,6 +177,6 @@ def f(value):
a = model.d_phase_d_param(toas, delay=None, param=param).to_value(1 / units)
b = df(getattr(model, param).value)
if param.startswith("FB"):
- assert np.amax(np.abs(a - b)) / np.amax(np.abs(a) + np.abs(b)) < 1e-6
+ assert np.amax(np.abs(a - b)) / np.amax(np.abs(a) + np.abs(b)) < 1.5e-6
else:
assert_allclose(a, b, atol=1e-4, rtol=1e-4)
diff --git a/tests/test_model_ifunc.py b/tests/test_model_ifunc.py
index 1dc11503b..8ef6a0ef1 100644
--- a/tests/test_model_ifunc.py
+++ b/tests/test_model_ifunc.py
@@ -1,6 +1,5 @@
-#! /usr/bin/env python
import os
-import unittest
+import pytest
import astropy.units as u
@@ -17,9 +16,9 @@
timfile = os.path.join(datadir, "j0007_ifunc.tim")
-class TestIFunc(unittest.TestCase):
+class TestIFunc:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.m = pint.models.get_model(parfile)
cls.t = pint.toa.get_TOAs(timfile, ephem="DE405", include_bipm=False)
@@ -29,9 +28,9 @@ def test_j0007(self):
rs = pint.residuals.Residuals(self.t, self.m)
rms = rs.time_resids.to(u.us).std()
chi2 = rs.reduced_chi2
- emsg = "RMS of " + str(rms.value) + " is too big."
+ emsg = f"RMS of {str(rms.value)} is too big."
assert rms < 2700.0 * u.us, emsg
- emsg = "reduced chi^2 of " + str(chi2) + " is too big."
+ emsg = f"reduced chi^2 of {str(chi2)} is too big."
assert chi2 < 1.1, emsg
# test a fit
@@ -40,9 +39,9 @@ def test_j0007(self):
rs = f.resids
rms = rs.time_resids.to(u.us).std()
chi2 = rs.reduced_chi2
- emsg = "RMS of " + str(rms.value) + " is too big."
+ emsg = f"RMS of {str(rms.value)} is too big."
assert rms < 2700.0 * u.us, emsg
- emsg = "reduced chi^2 of " + str(chi2) + " is too big."
+ emsg = f"reduced chi^2 of {str(chi2)} is too big."
assert chi2 < 1.1, emsg
# the residuals are actually terrible when using linear interpolation,
@@ -50,7 +49,3 @@ def test_j0007(self):
print("Test RMS of a PSR J0007+7303 ephemeris with IFUNCs(0).")
self.m.SIFUNC.quantity = 0
rs = pint.residuals.Residuals(self.t, self.m)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_model_manual.py b/tests/test_model_manual.py
index daa475ff7..3471ce016 100644
--- a/tests/test_model_manual.py
+++ b/tests/test_model_manual.py
@@ -105,7 +105,7 @@ def test_get_model_roundtrip(tmp_dir, parfile):
try:
m_old = get_model(parfile)
except (ValueError, IOError, MissingParameter) as e:
- pytest.skip("Existing code raised an exception {}".format(e))
+ pytest.skip(f"Existing code raised an exception {e}")
fn = join(tmp_dir, "file.par")
with open(fn, "w") as f:
diff --git a/tests/test_model_wave.py b/tests/test_model_wave.py
index 9d7d3e947..1e57d640f 100644
--- a/tests/test_model_wave.py
+++ b/tests/test_model_wave.py
@@ -1,6 +1,5 @@
-#! /usr/bin/env python
import os
-import unittest
+import pytest
import astropy.units as u
@@ -17,9 +16,9 @@
timfile = os.path.join(datadir, "vela_wave.tim")
-class TestWave(unittest.TestCase):
+class TestWave:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.m = pint.models.get_model(parfile)
cls.t = pint.toa.get_TOAs(timfile, ephem="DE405", include_bipm=False)
diff --git a/tests/test_modelconversions.py b/tests/test_modelconversions.py
index 8e935c40b..dddae8d9c 100644
--- a/tests/test_modelconversions.py
+++ b/tests/test_modelconversions.py
@@ -70,8 +70,39 @@ def test_ECL_to_ICRS():
assert np.allclose(r_ECL.resids, r_ICRS.resids)
+def test_ICRS_to_ECL_nouncertainties():
+ # start with ICRS model with no pm uncertainties, get residuals with ECL model, compare
+ model_ICRS = get_model(io.StringIO(modelstring_ICRS))
+ for p in ["PMRA", "PMDEC"]:
+ getattr(model_ICRS, p).frozen = True
+ getattr(model_ICRS, p).uncertainties = None
+
+ toas = pint.simulation.make_fake_toas_uniform(
+ MJDStart, MJDStop, NTOA, model=model_ICRS, error=1 * u.us, add_noise=True
+ )
+ r_ICRS = pint.residuals.Residuals(toas, model_ICRS)
+ r_ECL = pint.residuals.Residuals(toas, model_ICRS.as_ECL())
+ assert np.allclose(r_ECL.resids, r_ICRS.resids)
+ # assert model_ICRS.as_ECL(ecl).ECL.value == ecl
+
+
+def test_ECL_to_ICRS_nouncertainties():
+ # start with ECL model with no pm uncertainties, get residuals with ICRS model, compare
+ model_ECL = get_model(io.StringIO(modelstring_ECL))
+ for p in ["PMELONG", "PMELAT"]:
+ getattr(model_ECL, p).frozen = True
+ getattr(model_ECL, p).uncertainties = None
+
+ toas = pint.simulation.make_fake_toas_uniform(
+ MJDStart, MJDStop, NTOA, model=model_ECL, error=1 * u.us, add_noise=True
+ )
+ r_ECL = pint.residuals.Residuals(toas, model_ECL)
+ r_ICRS = pint.residuals.Residuals(toas, model_ECL.as_ICRS())
+ assert np.allclose(r_ECL.resids, r_ICRS.resids)
+
+
def test_ECL_to_ECL():
- # start with ECL model, get residuals with ECL model with differenct obliquity, compare
+ # start with ECL model, get residuals with ECL model with different obliquity, compare
model_ECL = get_model(io.StringIO(modelstring_ECL))
toas = pint.simulation.make_fake_toas_uniform(
diff --git a/tests/test_modeloverride.py b/tests/test_modeloverride.py
new file mode 100644
index 000000000..424714226
--- /dev/null
+++ b/tests/test_modeloverride.py
@@ -0,0 +1,86 @@
+import contextlib
+import io
+import os
+import time
+import pytest
+
+import astropy.units as u
+from astropy.time import Time
+import numpy as np
+from pint.models import get_model, get_model_and_toas
+import pint.simulation
+
+
+par = """PSRJ J1636-5133
+RAJ 16:35:44.7781433 1 0.05999748816321897513
+DECJ -51:34:18.01262 1 0.73332573676867170105
+F0 2.9404155099936412855 1 0.00000000013195919743
+F1 -1.4209854506981192501e-14 1 8.2230767370490607034e-17
+PEPOCH 60000
+DM 313
+BINARY ELL1
+PB 0.74181505310937273631 1 0.00000018999923507341
+A1 1.5231012457846993008 1 0.00050791514972366278
+TASC 59683.784709068155703 1 0.00004690256150561100"""
+
+
+# add in parameters that exist, parameters that don't, float, bool, and string. Then somem quantities
+@pytest.mark.parametrize(
+ ("k", "v"),
+ [
+ ("F1", -2e-14),
+ ("F2", 1e-12),
+ ("PSR", "ABC"),
+ ("DMDATA", True),
+ ("F1", -1e-10 * u.Hz / u.day),
+ ("F2", -1e-10 * u.Hz / u.day**2),
+ ("PEPOCH", Time(55000, format="pulsar_mjd", scale="tdb")),
+ ("JUMP", "mjd 55000 56000 0.03"),
+ ],
+)
+def test_paroverride(k, v):
+ kwargs = {k: v}
+ m = get_model(io.StringIO(par), **kwargs)
+ if isinstance(v, (str, bool)):
+ if k != "JUMP":
+ assert getattr(m, k).value == v
+ else:
+ assert getattr(m, f"{k}1").value == 0.03
+ assert getattr(m, f"{k}1").key == "mjd"
+ assert getattr(m, f"{k}1").key_value == [55000.0, 56000.0]
+ elif isinstance(v, u.Quantity):
+ assert np.isclose(getattr(m, k).quantity, v)
+ elif isinstance(v, Time):
+ assert getattr(m, k).quantity == v
+ else:
+ assert np.isclose(getattr(m, k).value, v)
+
+
+# these should fail:
+# adding F3 without F2
+# adding an unknown parameter
+# adding an improper value
+# adding a value with incorrect units
+@pytest.mark.parametrize(
+ ("k", "v"), [("F3", -2e-14), ("TEST", -1), ("F1", "test"), ("F1", -1e-10 * u.Hz)]
+)
+def test_paroverride_fails(k, v):
+ kwargs = {k: v}
+ with pytest.raises((AttributeError, ValueError)):
+ m = get_model(io.StringIO(par), **kwargs)
+
+
+# add in parameters that exist, parameters that don't, float, and string
+@pytest.mark.parametrize(("k", "v"), [("F1", -2e-14), ("F2", 1e-12), ("PSR", "ABC")])
+def test_paroverride_withtim(k, v):
+ kwargs = {k: v}
+ m = get_model(io.StringIO(par), **kwargs)
+ t = pint.simulation.make_fake_toas_uniform(50000, 58000, 20, model=m)
+ o = io.StringIO()
+ t.write_TOA_file(o)
+ o.seek(0)
+ m2, t2 = get_model_and_toas(io.StringIO(par), o, **kwargs)
+ if isinstance(v, str):
+ assert getattr(m2, k).value == v
+ else:
+ assert np.isclose(getattr(m2, k).value, v)
diff --git a/tests/test_modelutils.py b/tests/test_modelutils.py
index ec6d46bad..0df565c51 100644
--- a/tests/test_modelutils.py
+++ b/tests/test_modelutils.py
@@ -1,6 +1,6 @@
import logging
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -13,11 +13,11 @@
from pint.modelutils import model_equatorial_to_ecliptic, model_ecliptic_to_equatorial
-class TestEcliptic(unittest.TestCase):
+class TestEcliptic:
"""Test conversion from equatorial <-> ecliptic coordinates, and compare residuals."""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# J0613 is in equatorial
cls.parfileJ0613 = os.path.join(
datadir, "J0613-0200_NANOGrav_dfg+12_TAI_FB90.par"
@@ -50,8 +50,8 @@ def test_to_ecliptic(self):
assert (
"AstrometryEcliptic" in ECLmodelJ0613.components
), "Creation of ecliptic model failed"
- assert not (
- "AstrometryEquatorial" in ECLmodelJ0613.components
+ assert (
+ "AstrometryEquatorial" not in ECLmodelJ0613.components
), "Equatorial model still present"
self.log.debug("Ecliptic model created")
@@ -78,8 +78,8 @@ def test_to_equatorial(self):
assert (
"AstrometryEquatorial" in EQUmodelB1855.components
), "Creation of equatorial model failed"
- assert not (
- "AstrometryEcliptic" in EQUmodelB1855.components
+ assert (
+ "AstrometryEcliptic" not in EQUmodelB1855.components
), "Ecliptic model still present"
self.log.debug("Equatorial model created")
@@ -93,7 +93,3 @@ def test_to_equatorial(self):
% np.nanmax(np.abs(pint_resids - EQUpint_resids)).value
)
assert np.all(np.abs(pint_resids - EQUpint_resids) < 1e-10 * u.s), msg
-
-
-if __name__ == "__main__":
- pass
diff --git a/tests/test_numpy.py b/tests/test_numpy.py
index 75890dc42..cd13c3e8c 100644
--- a/tests/test_numpy.py
+++ b/tests/test_numpy.py
@@ -1,9 +1,8 @@
-#!/usr/bin/env python
import numpy as np
def test_str2longdouble():
- print("You are using numpy %s" % np.__version__)
+ print(f"You are using numpy {np.__version__}")
a = np.longdouble("0.12345678901234567890")
b = float("0.12345678901234567890")
# If numpy is converting to longdouble without loss of
diff --git a/tests/test_observatory.py b/tests/test_observatory.py
index d90de9075..e126b02d8 100644
--- a/tests/test_observatory.py
+++ b/tests/test_observatory.py
@@ -1,23 +1,31 @@
-#!/usr/bin/env python
import io
import os
import json
import astropy.units as u
+import contextlib
import numpy as np
import pytest
from astropy import units as u
from pint.pulsar_mjd import Time
+from pinttestdata import datadir as testdatadir
import pint.observatory
-from pint.observatory import NoClockCorrections, Observatory, get_observatory
+from pint.observatory import (
+ NoClockCorrections,
+ Observatory,
+ get_observatory,
+ compare_t2_observatories_dat,
+ compare_tempo_obsys_dat,
+)
from pint.pulsar_mjd import Time
import pint.observatory.topo_obs
from pint.observatory.topo_obs import (
TopoObs,
load_observatories,
)
+from collections import defaultdict
tobs = ["aro", "ao", "chime", "drao"]
@@ -114,10 +122,8 @@ class Sandbox:
o = Sandbox()
e = os.environ.copy()
- try:
+ with contextlib.suppress(KeyError):
del os.environ["PINT_OBS_OVERRIDE"]
- except KeyError:
- pass
reg = pint.observatory.Observatory._registry.copy()
try:
yield o
@@ -270,7 +276,7 @@ def test_json_observatory_output(sandbox):
gbt_reload = get_observatory("gbt")
for p in gbt_orig.__dict__:
- if not p in ["_clock"]:
+ if p not in ["_clock"]:
assert getattr(gbt_orig, p) == getattr(gbt_reload, p)
@@ -287,7 +293,7 @@ def test_json_observatory_input_latlon(sandbox):
gbt_reload = get_observatory("gbt")
for p in gbt_orig.__dict__:
- if not p in ["location", "_clock"]:
+ if p not in ["location", "_clock"]:
# everything else should be identical
assert getattr(gbt_orig, p) == getattr(gbt_reload, p)
# check distance separately to allow for precision
@@ -315,3 +321,18 @@ def test_valid_past_end():
o = pint.observatory.get_observatory("jbroach")
o.last_clock_correction_mjd()
o.clock_corrections(o._clock[0].time[-1] + 1 * u.d, limits="error")
+
+
+def test_names_and_aliases():
+ na = Observatory.names_and_aliases()
+ assert isinstance(na, dict) and isinstance(na["gbt"], list)
+
+
+def test_compare_t2_observatories_dat():
+ s = compare_t2_observatories_dat(testdatadir)
+ assert isinstance(s, defaultdict)
+
+
+def test_compare_tempo_obsys_dat():
+ s = compare_tempo_obsys_dat(testdatadir / "observatory")
+ assert isinstance(s, defaultdict)
diff --git a/tests/test_observatory_envar.py b/tests/test_observatory_envar.py
index f12c92368..a364478fc 100644
--- a/tests/test_observatory_envar.py
+++ b/tests/test_observatory_envar.py
@@ -1,3 +1,4 @@
+import contextlib
import os
import pytest
import importlib
@@ -15,10 +16,8 @@ class Sandbox:
o = Sandbox()
e = os.environ.copy()
- try:
+ with contextlib.suppress(KeyError):
del os.environ["PINT_OBS_OVERRIDE"]
- except KeyError:
- pass
reg = pint.observatory.Observatory._registry.copy()
o.override_dir = tmp_path / "override"
o.override_dir.mkdir()
diff --git a/tests/test_observatory_metadata.py b/tests/test_observatory_metadata.py
index 5404a4d14..990a71c66 100644
--- a/tests/test_observatory_metadata.py
+++ b/tests/test_observatory_metadata.py
@@ -1,13 +1,14 @@
+import pytest
import logging
-import unittest
+import pytest
import pint.observatory
-class TestObservatoryMetadata(unittest.TestCase):
+class TestObservatoryMetadata:
"""Test handling of observatory metadata"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
# name of an observatory that PINT should know about
# and should have metadata on
cls.pint_obsname = "gbt"
@@ -21,10 +22,7 @@ def test_astropy_observatory(self):
try to instantiate the observatory in PINT from astropy and check their metadata
"""
keck = pint.observatory.get_observatory(self.astropy_obsname)
- msg = (
- "Checking PINT metadata for '%s' failed: 'astropy' not present in '%s'"
- % (self.astropy_obsname, keck.origin)
- )
+ msg = f"Checking PINT metadata for '{self.astropy_obsname}' failed: 'astropy' not present in '{keck.origin}'"
assert "astropy" in keck.origin, msg
def test_pint_observatory(self):
@@ -32,10 +30,7 @@ def test_pint_observatory(self):
try to instantiate the observatory in PINT and check their metadata
"""
gbt = pint.observatory.get_observatory(self.pint_obsname)
- msg = "Checking PINT definition for '%s' failed: metadata is '%s'" % (
- self.pint_obsname,
- gbt.origin,
- )
+ msg = f"Checking PINT definition for '{self.pint_obsname}' failed: metadata is '{gbt.origin}'"
assert (gbt.origin is not None) and (len(gbt.origin) > 0), msg
def test_observatory_replacement(self):
@@ -50,7 +45,7 @@ def test_observatory_replacement(self):
origin="Inserted for testing purposes",
)
obs = pint.observatory.get_observatory(obsname)
- self.assertRaises(
+ pytest.raises(
ValueError,
TopoObs,
obsname,
@@ -58,11 +53,8 @@ def test_observatory_replacement(self):
origin="This is a test - replacement",
)
obs = pint.observatory.get_observatory(obsname)
- msg = (
- "Checking that 'replacement' is not in the metadata for '%s': metadata is '%s'"
- % (obsname, obs.origin)
- )
- assert not ("replacement" in obs.origin), msg
+ msg = f"Checking that 'replacement' is not in the metadata for '{obsname}': metadata is '{obs.origin}'"
+ assert "replacement" not in obs.origin, msg
TopoObs(
obsname,
itrf_xyz=[882589.65, -4924872.32, 3943729.348],
@@ -70,8 +62,5 @@ def test_observatory_replacement(self):
overwrite=True,
)
obs = pint.observatory.get_observatory(obsname)
- msg = (
- "Checking that 'replacement' is now in the metadata for '%s': metadata is '%s'"
- % (obsname, obs.origin)
- )
+ msg = f"Checking that 'replacement' is now in the metadata for '{obsname}': metadata is '{obs.origin}'"
assert "replacement" in obs.origin, msg
diff --git a/tests/test_orbit_phase.py b/tests/test_orbit_phase.py
index 530916f2f..cc73fc4f3 100644
--- a/tests/test_orbit_phase.py
+++ b/tests/test_orbit_phase.py
@@ -1,6 +1,6 @@
-#! /usr/bin/env python
+import pytest
import os
-import unittest
+import pytest
import numpy as np
@@ -9,11 +9,11 @@
from pinttestdata import datadir
-class TestOrbitPhase(unittest.TestCase):
+class TestOrbitPhase:
"""Test orbital phase calculations"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.pJ1855 = "B1855+09_NANOGrav_dfg+12_modified_DD.par"
cls.mJ1855 = m.get_model(cls.pJ1855)
@@ -21,10 +21,10 @@ def setUpClass(cls):
def test_barytimes(self):
ts = t.Time([56789.234, 56790.765], format="mjd")
# Should raise ValueError since not in "tdb"
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.mJ1855.orbital_phase(ts)
# Should raise ValueError since not correct anom
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.mJ1855.orbital_phase(ts.tdb, anom="xxx")
# Should return
phs = self.mJ1855.orbital_phase(ts.tdb, anom="mean")
@@ -33,7 +33,7 @@ def test_barytimes(self):
phs = self.mJ1855.orbital_phase(toas)
assert len(phs) == toas.ntoas
- def test_J1855_nonzero_ecc(self):
+ def test_j1855_nonzero_ecc(self):
ts = self.mJ1855.T0.value + np.linspace(0, self.mJ1855.PB.value, 101)
self.mJ1855.ECC.value = 0.1 # set the eccentricity to nonzero
phs = self.mJ1855.orbital_phase(ts, anom="mean", radians=False)
@@ -49,7 +49,7 @@ def test_J1855_nonzero_ecc(self):
assert np.isclose(phs3[0].value, phs[0].value), "Eccen anom != True anom"
assert phs3[1] != phs[49], "Eccen anom == True anom"
- def test_J1855_zero_ecc(self):
+ def test_j1855_zero_ecc(self):
self.mJ1855.ECC.value = 0.0 # set the eccentricity to zero
self.mJ1855.OM.value = 0.0 # set omega to zero
phs1 = self.mJ1855.orbital_phase(self.mJ1855.T0.value, anom="mean")
@@ -61,7 +61,7 @@ def test_J1855_zero_ecc(self):
phs3 = self.mJ1855.orbital_phase(self.mJ1855.T0.value + 0.1, anom="true")
assert np.isclose(phs3.value, phs1.value), "True anom != Mean anom"
- def test_J1855_ell1(self):
+ def test_j1855_ell1(self):
mJ1855ell1 = m.get_model("B1855+09_NANOGrav_12yv3.wb.gls.par")
phs1 = mJ1855ell1.orbital_phase(mJ1855ell1.TASC.value, anom="mean")
assert np.isclose(phs1.value, 0.0), "Mean anom != 0.0 at TASC as value"
@@ -74,7 +74,7 @@ def test_J1855_ell1(self):
phs3 = mJ1855ell1.orbital_phase(mJ1855ell1.TASC.value + 0.1, anom="true")
assert np.isclose(phs3.value, phs1.value), "True anom != Mean anom"
- def test_J0737(self):
+ def test_j0737(self):
# Find a conjunction which we have confirmed by GBT data and Shapiro delay
mJ0737 = m.get_model("0737A_latest.par")
x = mJ0737.conjunction(55586.25)
diff --git a/tests/test_parametercovariancematrix.py b/tests/test_parameter_covariance_matrix.py
similarity index 100%
rename from tests/test_parametercovariancematrix.py
rename to tests/test_parameter_covariance_matrix.py
diff --git a/tests/test_parameters.py b/tests/test_parameters.py
index d75ad179e..3a6a7b832 100644
--- a/tests/test_parameters.py
+++ b/tests/test_parameters.py
@@ -1,11 +1,12 @@
import os
import pickle
-import unittest
+import pytest
import io
import astropy.time as time
import astropy.units as u
import numpy as np
+import pathlib
import pytest
from numpy.testing import assert_allclose
from pinttestdata import datadir
@@ -137,50 +138,48 @@ def test_units_consistent():
assert pm.uncertainty_value * pm.units == pm.uncertainty
-class TestParameters(unittest.TestCase):
+class TestParameters:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.m = get_model("B1855+09_NANOGrav_dfg+12_modified.par")
cls.mp = get_model("prefixtest.par")
- def test_RAJ(self):
+ def test_raj(self):
"""Check whether the value and units of RAJ parameter are ok"""
units = u.hourangle
value = 18.960109246777776
- self.assertEqual(self.m.RAJ.units, units)
- self.assertEqual(self.m.RAJ.value, value)
- self.assertEqual(self.m.RAJ.quantity, value * units)
+ assert self.m.RAJ.units == units
+ assert self.m.RAJ.value == value
+ assert self.m.RAJ.quantity == value * units
- def test_DECJ(self):
+ def test_decj(self):
"""Check whether the value and units of DECJ parameter are ok"""
units = u.deg
value = 9.72146998888889
- self.assertEqual(self.m.DECJ.units, units)
- self.assertEqual(self.m.DECJ.value, value)
- self.assertEqual(self.m.DECJ.quantity, value * units)
+ assert self.m.DECJ.units == units
+ assert self.m.DECJ.value == value
+ assert self.m.DECJ.quantity == value * units
- def test_F0(self):
+ def test_f0(self):
"""Check whether the value and units of F0 parameter are ok"""
units = u.Hz
value = np.longdouble("186.49408156698235146")
- self.assertEqual(self.m.F0.units, units)
- self.assertTrue(np.isclose(self.m.F0.value, value, atol=1e-19))
- self.assertEqual(self.m.F0.value, value)
+ assert self.m.F0.units == units
+ assert np.isclose(self.m.F0.value, value, atol=1e-19)
+ assert self.m.F0.value == value
- def test_F0_uncertainty(self):
+ def test_f0_uncertainty(self):
uncertainty = 0.00000000000698911818
units = self.m.F0.units
# Test stored uncertainty value
- self.assertTrue(
- np.isclose(self.m.F0.uncertainty.to(units).value, uncertainty, atol=1e-20)
+ assert np.isclose(
+ self.m.F0.uncertainty.to(units).value, uncertainty, atol=1e-20
)
# Test parameter.uncertainty_value returned value
- self.assertTrue(
- np.isclose(self.m.F0.uncertainty_value, uncertainty, atol=1e-20)
- )
+ assert np.isclose(self.m.F0.uncertainty_value, uncertainty, atol=1e-20)
def test_set_new_units(self):
"""Check whether we can set the units to non-standard ones"""
@@ -196,32 +195,28 @@ def test_set_new_units(self):
uncertainty_value_new = 0.00000000000698911818 / 1000.0
# Set it to 186.49 Hz: the 'standard' value
self.m.F0.quantity = value * units
- self.assertTrue(np.isclose(self.m.F0.value, value, atol=1e-13))
+ assert np.isclose(self.m.F0.value, value, atol=1e-13)
# Now change the units
self.m.F0.units = str_unit_new
- self.assertTrue(np.isclose(self.m.F0.value, value_new, atol=1e-13))
+ assert np.isclose(self.m.F0.value, value_new, atol=1e-13)
- self.assertTrue(
- np.isclose(self.m.F0.uncertainty_value, uncertainty_value_new, atol=1e-13)
+ assert np.isclose(
+ self.m.F0.uncertainty_value, uncertainty_value_new, atol=1e-13
)
# Change the units back, and then set them implicitly
# The value will be associate with the new units
self.m.F0.units = str_unit
self.m.F0.quantity = value_new * units_new
self.m.F0.uncertainty = uncertainty_value_new * units_new
- self.assertTrue(np.isclose(self.m.F0.value, value, atol=1e-13))
- self.assertTrue(
- np.isclose(self.m.F0.uncertainty_value, uncertainty_value, atol=1e-20)
- )
+ assert np.isclose(self.m.F0.value, value, atol=1e-13)
+ assert np.isclose(self.m.F0.uncertainty_value, uncertainty_value, atol=1e-20)
# Check the ratio, using the old units as a reference
ratio = self.m.F0.quantity / (value * units)
ratio_uncertainty = self.m.F0.uncertainty / (uncertainty_value * units)
- self.assertTrue(np.isclose(ratio.decompose(u.si.bases), 1.0, atol=1e-13))
+ assert np.isclose(ratio.decompose(u.si.bases), 1.0, atol=1e-13)
- self.assertTrue(
- np.isclose(ratio_uncertainty.decompose(u.si.bases), 1.0, atol=1e-20)
- )
+ assert np.isclose(ratio_uncertainty.decompose(u.si.bases), 1.0, atol=1e-20)
def set_units_fail(self):
"""Setting the unit to a non-compatible unit should fail"""
@@ -229,7 +224,7 @@ def set_units_fail(self):
def test_units(self):
"""Test setting the units"""
- self.assertRaises(u.UnitConversionError, self.set_units_fail)
+ pytest.raises(u.UnitConversionError, self.set_units_fail)
def set_num_to_unit(self):
"""Try to set the numerical value to a unit"""
@@ -241,16 +236,18 @@ def set_num_to_quantity(self):
def test_set_value(self):
"""Try to set the numerical value of a parameter to various things"""
- self.assertRaises(ValueError, self.set_num_to_unit)
- self.assertRaises(ValueError, self.set_num_to_quantity)
+ with pytest.raises(ValueError):
+ self.set_num_to_unit()
+ with pytest.raises(u.UnitTypeError):
+ self.set_num_to_quantity()
- def test_T0(self):
+ def test_t0(self):
"""Test setting T0 to a test value"""
self.m.T0.value = 50044.3322
# I don't understand why this is failing... something about float128
# Does not fail for me (both lines) -- RvH 02/22/2015
- self.assertTrue(np.isclose(self.m.T0.value, 50044.3322))
- self.assertEqual(self.m.T0.value, 50044.3322)
+ assert np.isclose(self.m.T0.value, 50044.3322)
+ assert self.m.T0.value == 50044.3322
def set_num_to_none(self):
"""Set T0 to None"""
@@ -262,8 +259,8 @@ def set_num_to_string(self):
def test_num_to_other(self):
"""Test setting the T0 numerical value to a not-number"""
- self.assertRaises(ValueError, self.set_num_to_none)
- self.assertRaises(ValueError, self.set_num_to_string)
+ pytest.raises(ValueError, self.set_num_to_none)
+ pytest.raises(ValueError, self.set_num_to_string)
def set_OM_to_none(self):
"""Set OM to None"""
@@ -273,24 +270,24 @@ def set_OM_to_time(self):
"""Set OM to a time"""
self.m.OM.value = time.Time(54000, format="mjd")
- def test_OM(self):
+ def test_om(self):
"""Test doing stuff to OM"""
quantity = 10.0 * u.deg
self.m.OM.quantity = quantity
- self.assertEqual(self.m.OM.quantity, quantity)
- self.assertRaises(ValueError, self.set_OM_to_none)
- self.assertRaises(TypeError, self.set_OM_to_time)
+ assert self.m.OM.quantity == quantity
+ pytest.raises(ValueError, self.set_OM_to_none)
+ pytest.raises(TypeError, self.set_OM_to_time)
- def test_PBDOT(self):
+ def test_pbdot(self):
# Check that parameter scaling is working as expected
# Units are not modified, just the value is scaled
self.m.PBDOT.value = 20
- self.assertEqual(self.m.PBDOT.units, u.day / u.day)
- self.assertEqual(self.m.PBDOT.quantity, 20 * 1e-12 * u.day / u.day)
+ assert self.m.PBDOT.units == u.day / u.day
+ assert self.m.PBDOT.quantity == 20 * 1e-12 * u.day / u.day
self.m.PBDOT.value = 1e-11
- self.assertEqual(self.m.PBDOT.units, u.day / u.day)
- self.assertEqual(self.m.PBDOT.quantity, 1e-11 * u.day / u.day)
+ assert self.m.PBDOT.units == u.day / u.day
+ assert self.m.PBDOT.quantity == 1e-11 * u.day / u.day
def test_prefix_value_to_num(self):
"""Test setting the prefix parameter"""
@@ -298,11 +295,11 @@ def test_prefix_value_to_num(self):
units = u.Hz
self.mp.GLF0_2.value = value
- self.assertEqual(self.mp.GLF0_2.quantity, value * units)
+ assert self.mp.GLF0_2.quantity == value * units
value = 50
self.mp.GLF0_2.value = value
- self.assertEqual(self.mp.GLF0_2.quantity, value * units)
+ assert self.mp.GLF0_2.quantity == value * units
def test_prefix_value_str(self):
"""Test setting the prefix parameter from a string"""
@@ -312,7 +309,7 @@ def test_prefix_value_str(self):
self.mp.GLF0_2.value = str_value
- self.assertEqual(self.mp.GLF0_2.value, value * units)
+ assert self.mp.GLF0_2.value == value * units
def set_prefix_value_to_unit_fail(self):
"""Set the prefix parameter to an incompatible value"""
@@ -323,7 +320,7 @@ def set_prefix_value_to_unit_fail(self):
def test_prefix_value_fail(self):
"""Test setting the prefix parameter to an incompatible value"""
- self.assertRaises(ValueError, self.set_prefix_value_to_unit_fail)
+ pytest.raises(ValueError, self.set_prefix_value_to_unit_fail)
def test_prefix_value1(self):
self.mp.GLF0_2.value = 50
@@ -341,9 +338,9 @@ def set_prefix_value1(self):
self.mp.GLF0_2.value = 100 * u.s
def test_prefix_value1(self):
- self.assertRaises(ValueError, self.set_prefix_value1)
+ pytest.raises(ValueError, self.set_prefix_value1)
- def test_START_FINISH_in_par(self):
+ def test_start_finish_in_par(self):
"""
Check that START/FINISH parameters set up/operate properly when
from input file.
@@ -362,10 +359,10 @@ def test_START_FINISH_in_par(self):
assert hasattr(m1, "FINISH")
assert isinstance(m1.FINISH, MJDParameter)
- self.assertEqual(m1.START.value, start_preval)
- self.assertEqual(m1.FINISH.value, finish_preval)
- self.assertEqual(m1.START.frozen, True)
- self.assertEqual(m1.FINISH.frozen, True)
+ assert m1.START.value == start_preval
+ assert m1.FINISH.value == finish_preval
+ assert m1.START.frozen == True
+ assert m1.FINISH.frozen == True
# fit toas and compare with expected/Tempo2 (for WLS) values
fitters = [
@@ -374,23 +371,23 @@ def test_START_FINISH_in_par(self):
]
for fitter in fitters:
fitter.fit_toas()
- self.assertEqual(m1.START.frozen, True)
- self.assertEqual(m1.FINISH.frozen, True)
+ assert m1.START.frozen == True
+ assert m1.FINISH.frozen == True
if fitter.method == "weighted_least_square":
- self.assertAlmostEqual(
- fitter.model.START.value, start_postval, places=9
+ assert fitter.model.START.value == pytest.approx(
+ start_postval, abs=1e-09
)
- self.assertAlmostEqual(
- fitter.model.FINISH.value, finish_postval, places=9
+ assert fitter.model.FINISH.value == pytest.approx(
+ finish_postval, abs=1e-09
)
- self.assertAlmostEqual(
- fitter.model.START.value, fitter.toas.first_MJD.value, places=9
+ assert fitter.model.START.value == pytest.approx(
+ fitter.toas.first_MJD.value, abs=1e-09
)
- self.assertAlmostEqual(
- fitter.model.FINISH.value, fitter.toas.last_MJD.value, places=9
+ assert fitter.model.FINISH.value == pytest.approx(
+ fitter.toas.last_MJD.value, abs=1e-09
)
- def test_START_FINISH_not_in_par(self):
+ def test_start_finish_not_in_par(self):
"""
Check that START/FINISH parameters are added and set up when not
in input file.
@@ -402,8 +399,8 @@ def test_START_FINISH_not_in_par(self):
start_postval = 53478.2858714192 # from Tempo2
finish_postval = 54187.5873241699 # from Tempo2
- self.assertTrue(hasattr(m, "START"))
- self.assertTrue(hasattr(m, "FINISH"))
+ assert hasattr(m, "START")
+ assert hasattr(m, "FINISH")
# fit toas and compare with expected/Tempo2 (for WLS) values
fitters = [
@@ -412,43 +409,40 @@ def test_START_FINISH_not_in_par(self):
]
for fitter in fitters:
fitter.fit_toas()
- self.assertTrue(hasattr(fitter.model, "START"))
- self.assertTrue(hasattr(fitter.model, "FINISH"))
- self.assertEqual(fitter.model.START.frozen, True)
- self.assertEqual(fitter.model.FINISH.frozen, True)
+ assert hasattr(fitter.model, "START")
+ assert hasattr(fitter.model, "FINISH")
+ assert fitter.model.START.frozen == True
+ assert fitter.model.FINISH.frozen == True
if fitter.method == "weighted_least_square":
- self.assertAlmostEqual(
- fitter.model.START.value, start_postval, places=9
+ assert fitter.model.START.value == pytest.approx(
+ start_postval, abs=1e-09
)
- self.assertAlmostEqual(
- fitter.model.FINISH.value, finish_postval, places=9
+ assert fitter.model.FINISH.value == pytest.approx(
+ finish_postval, abs=1e-09
)
- self.assertAlmostEqual(
- fitter.model.START.value, fitter.toas.first_MJD.value, places=9
+ assert fitter.model.START.value == pytest.approx(
+ fitter.toas.first_MJD.value, abs=1e-09
)
- self.assertAlmostEqual(
- fitter.model.FINISH.value, fitter.toas.last_MJD.value, places=9
+ assert fitter.model.FINISH.value == pytest.approx(
+ fitter.toas.last_MJD.value, abs=1e-09
)
- def test_START_FINISH_notfrozen(self):
+ def test_start_finish_notfrozen(self):
"""
check that when the START/FINISH parameters
are added as unfrozen it warns and fixes
"""
- # check initialization after fitting for .par file without START/FINISH
- with open("NGC6440E.par") as f:
- s = f.read()
- s += "START 54000 1\nFINISH 55000 1\n"
+ s = pathlib.Path("NGC6440E.par").read_text() + "START 54000 1\nFINISH 55000 1\n"
# make sure that it warns
with pytest.warns(UserWarning, match=r"cannot be unfrozen"):
m = get_model(io.StringIO(s))
- self.assertTrue(hasattr(m, "START"))
- self.assertTrue(hasattr(m, "FINISH"))
+ assert hasattr(m, "START")
+ assert hasattr(m, "FINISH")
# make sure that it freezes
- self.assertEqual(m.START.frozen, True)
- self.assertEqual(m.FINISH.frozen, True)
+ assert m.START.frozen == True
+ assert m.FINISH.frozen == True
@pytest.mark.parametrize(
diff --git a/tests/test_parfile.py b/tests/test_parfile.py
index 73e2dc620..bef686772 100644
--- a/tests/test_parfile.py
+++ b/tests/test_parfile.py
@@ -1,5 +1,3 @@
-#! /usr/bin/env python
-
import tempfile
import pytest
diff --git a/tests/test_parfile_writing.py b/tests/test_parfile_writing.py
index 95d41f68b..5e5dc2a9b 100644
--- a/tests/test_parfile_writing.py
+++ b/tests/test_parfile_writing.py
@@ -24,16 +24,14 @@ def test_parfile_write(tmp_path):
for p in modelB1855.params:
par = getattr(modelB1855, p)
# Change value for 20%
- if isinstance(par.value, numbers.Number):
+ if isinstance(par.value, numbers.Number) and not isinstance(
+ par, mp.MJDParameter
+ ):
ov = par.value
- if isinstance(par, mp.MJDParameter):
- continue
- else:
- par.value = ov * 0.8
+ par.value = ov * 0.8
res = Residuals(toasB1855, modelB1855, use_weighted_mean=False).time_resids.to(u.s)
- f = open(out_parfile, "w")
- f.write(modelB1855.as_parfile())
- f.close()
+ with open(out_parfile, "w") as f:
+ f.write(modelB1855.as_parfile())
read_model = mb.get_model(out_parfile)
read_res = Residuals(toasB1855, read_model, use_weighted_mean=False).time_resids.to(
u.s
diff --git a/tests/test_parfile_writing_format.py b/tests/test_parfile_writing_format.py
index 49ce86d83..f2795a1d5 100644
--- a/tests/test_parfile_writing_format.py
+++ b/tests/test_parfile_writing_format.py
@@ -16,9 +16,9 @@ def test_SWM():
m = get_model(os.path.join(datadir, "B1855+09_NANOGrav_9yv1.gls.par"))
assert (
- ("SWM" in m.as_parfile())
- and not ("SWM" in m.as_parfile(format="tempo"))
- and not ("SWM" in m.as_parfile(format="tempo2"))
+ "SWM" in m.as_parfile()
+ and "SWM" not in m.as_parfile(format="tempo")
+ and "SWM" not in m.as_parfile(format="tempo2")
)
@@ -30,8 +30,8 @@ def test_CHI2():
f = fitter.WLSFitter(toas=t, model=m)
assert "CHI2" in f.model.as_parfile()
- assert not ("CHI2" in f.model.as_parfile(format="tempo2"))
- assert not ("CHI2" in f.model.as_parfile(format="tempo"))
+ assert "CHI2" not in f.model.as_parfile(format="tempo2")
+ assert "CHI2" not in f.model.as_parfile(format="tempo")
def test_T2CMETHOD():
@@ -62,12 +62,12 @@ def test_STIGMA():
"""Should get changed to VARSIGMA for TEMPO/TEMPO2"""
m = get_model(os.path.join(datadir, "J0613-0200_NANOGrav_9yv1_ELL1H_STIG.gls.par"))
assert (
- ("STIGMA" in m.as_parfile())
- and not ("VARSIGMA" in m.as_parfile())
- and not ("STIGMA" in m.as_parfile(format="tempo"))
- and ("VARSIGMA" in m.as_parfile(format="tempo"))
- and not ("STIGMA" in m.as_parfile(format="tempo2"))
- and ("VARSIGMA" in m.as_parfile(format="tempo2"))
+ "STIGMA" in m.as_parfile()
+ and "VARSIGMA" not in m.as_parfile()
+ and "STIGMA" not in m.as_parfile(format="tempo")
+ and "VARSIGMA" in m.as_parfile(format="tempo")
+ and "STIGMA" not in m.as_parfile(format="tempo2")
+ and "VARSIGMA" in m.as_parfile(format="tempo2")
)
@@ -75,12 +75,12 @@ def test_A1DOT():
"""Should get changed to XDOT for TEMPO/TEMPO2"""
m = get_model(os.path.join(datadir, "J1600-3053_test.par"))
assert (
- ("A1DOT" in m.as_parfile())
- and not ("XDOT" in m.as_parfile())
- and not ("A1DOT" in m.as_parfile(format="tempo"))
- and ("XDOT" in m.as_parfile(format="tempo"))
- and not ("A1DOT" in m.as_parfile(format="tempo2"))
- and ("XDOT" in m.as_parfile(format="tempo2"))
+ "A1DOT" in m.as_parfile()
+ and "XDOT" not in m.as_parfile()
+ and "A1DOT" not in m.as_parfile(format="tempo")
+ and "XDOT" in m.as_parfile(format="tempo")
+ and "A1DOT" not in m.as_parfile(format="tempo2")
+ and "XDOT" in m.as_parfile(format="tempo2")
)
diff --git a/tests/test_parunits.py b/tests/test_parunits.py
new file mode 100644
index 000000000..21f15927c
--- /dev/null
+++ b/tests/test_parunits.py
@@ -0,0 +1,57 @@
+import pytest
+from astropy import units as u
+from pint.models.timing_model import UnknownParameter
+from pint.utils import get_unit
+
+
+@pytest.mark.parametrize(
+ "p",
+ [
+ "F0",
+ "F1",
+ "DM",
+ "DMX_0001",
+ "DMX_0002",
+ "DMXR1_0001",
+ "POSEPOCH",
+ "PMRA",
+ "PMELONG",
+ "FB0",
+ "FB12",
+ "PB",
+ "RA",
+ "A1",
+ "M2",
+ "EDOT",
+ "ECC",
+ "OM",
+ "T0",
+ "TASC",
+ "XDOT",
+ "EFAC",
+ "EQUAD",
+ "JUMP1",
+ ],
+)
+def test_par_units(p):
+ unit = get_unit(p)
+ print(f"{p}: {unit}")
+ assert isinstance(unit, u.UnitBase)
+
+
+# strings should have units of None
+@pytest.mark.parametrize(
+ "p",
+ [
+ "PLANET_SHAPIRO",
+ ],
+)
+def test_par_units_none(p):
+ unit = get_unit(p)
+ print(f"{p}: {unit}")
+ assert unit is None
+
+
+def test_par_units_fails():
+ with pytest.raises(UnknownParameter):
+ unit = get_unit("notapar")
diff --git a/tests/test_pb.py b/tests/test_pb.py
new file mode 100644
index 000000000..a5b918b90
--- /dev/null
+++ b/tests/test_pb.py
@@ -0,0 +1,52 @@
+import numpy as np
+from astropy import units as u, constants as c
+from astropy.time import Time
+import os
+from pinttestdata import datadir
+import pytest
+
+from pint.models import get_model
+
+
+def test_fb():
+ # with FB terms
+ m = get_model(os.path.join(datadir, "J0023+0923_NANOGrav_11yv0.gls.par"))
+ assert np.isclose(m.pb()[0].to_value(u.d), (1 / m.FB0.quantity).to_value(u.d))
+
+
+@pytest.mark.parametrize(
+ "t",
+ [
+ Time(55555, format="pulsar_mjd", scale="tdb", precision=9),
+ 55555 * u.d,
+ 55555.0,
+ 55555,
+ "55555",
+ np.array([55555, 55556]),
+ ],
+)
+def test_fb_input(t):
+ # with FB terms
+ m = get_model(os.path.join(datadir, "J0023+0923_NANOGrav_11yv0.gls.par"))
+ pb, pberr = m.pb(t)
+
+
+def test_pb():
+ m = get_model(os.path.join(datadir, "J0437-4715.par"))
+ assert np.isclose(m.pb()[0].to_value(u.d), m.PB.quantity.to_value(u.d))
+
+
+@pytest.mark.parametrize(
+ "t",
+ [
+ Time(55555, format="pulsar_mjd", scale="tdb", precision=9),
+ 55555 * u.d,
+ 55555.0,
+ 55555,
+ "55555",
+ np.array([55555, 55556]),
+ ],
+)
+def test_pb(t):
+ m = get_model(os.path.join(datadir, "J0437-4715.par"))
+ pb, pberr = m.pb(t)
diff --git a/tests/test_phase_commands.py b/tests/test_phase_commands.py
index 53947ffda..f20764a63 100644
--- a/tests/test_phase_commands.py
+++ b/tests/test_phase_commands.py
@@ -1,6 +1,5 @@
-#!/usr/bin/python
import os
-import unittest
+import pytest
import pint.models
import pint.toa
@@ -13,7 +12,7 @@
timfile = os.path.join(datadir, "NGC6440E_PHASETEST.tim")
-class TestAbsPhase(unittest.TestCase):
+class TestAbsPhase:
def test_phase_commands(self):
model = pint.models.get_model(parfile)
toas = pint.toa.get_TOAs(timfile)
diff --git a/tests/test_phase_offset.py b/tests/test_phase_offset.py
new file mode 100644
index 000000000..d28fc26c0
--- /dev/null
+++ b/tests/test_phase_offset.py
@@ -0,0 +1,104 @@
+"Tests for the PhaseOffset component."
+
+import numpy as np
+import io
+
+from pint.models import get_model_and_toas, get_model
+from pint.models.phase_offset import PhaseOffset
+from pint.residuals import Residuals
+from pint.simulation import make_fake_toas_uniform
+from pint.fitter import WLSFitter
+
+from pinttestdata import datadir
+
+parfile = datadir / "NGC6440E.par"
+timfile = datadir / "NGC6440E.tim"
+
+
+def test_phase_offset():
+ simplepar = """
+ ELAT 5.6 1
+ ELONG 3.2 1
+ F0 100 1
+ PEPOCH 50000
+ PHOFF 0.2 1
+ """
+ m = get_model(io.StringIO(simplepar))
+
+ assert hasattr(m, "PHOFF") and m.PHOFF.value == 0.2
+
+ t = make_fake_toas_uniform(
+ startMJD=50000,
+ endMJD=50500,
+ ntoas=100,
+ model=m,
+ add_noise=True,
+ )
+
+ res = Residuals(t, m)
+ assert res.reduced_chi2 < 1.5
+
+ # Simulated TOAs should have zero residual mean..
+ mean0 = res.calc_phase_mean().value
+ assert np.isclose(mean0, 0, atol=1e-3)
+
+ res.model.PHOFF.value = -0.1
+ mean1 = res.calc_phase_mean().value
+ # The new mean should be equal to (0.2 - -0.1).
+ assert np.isclose(mean1 - mean0, 0.3, atol=1e-3)
+ assert np.isclose(
+ np.average(res.calc_phase_resids(), weights=1 / t.get_errors() ** 2),
+ 0.3,
+ atol=1e-3,
+ )
+
+ # subtract_mean option should be ignored.
+ assert np.allclose(
+ res.calc_phase_resids(subtract_mean=True),
+ res.calc_phase_resids(subtract_mean=False),
+ )
+
+ ftr = WLSFitter(t, m)
+ ftr.fit_toas(maxiter=3)
+ assert ftr.resids.reduced_chi2 < 1.5
+
+
+def test_phase_offset_designmatrix():
+ m, t = get_model_and_toas(parfile, timfile)
+
+ M1, pars1, units1 = m.designmatrix(t, incoffset=True)
+
+ assert "Offset" in pars1
+
+ offset_idx = pars1.index("Offset")
+ M1_offset = M1[:, offset_idx]
+
+ po = PhaseOffset()
+ m.add_component(po)
+ m.PHOFF.frozen = False
+
+ M2, pars2, units2 = m.designmatrix(t, incoffset=True)
+
+ assert "Offset" not in pars2
+ assert "PHOFF" in pars2
+
+ phoff_idx = pars2.index("PHOFF")
+ M2_phoff = M2[:, phoff_idx]
+
+ assert np.allclose(M1_offset, M2_phoff)
+ assert M1.shape == M2.shape
+
+
+def test_fit_real_data():
+ m, t = get_model_and_toas(parfile, timfile)
+
+ po = PhaseOffset()
+ m.add_component(po)
+ m.PHOFF.frozen = False
+
+ ftr = WLSFitter(t, m)
+ ftr.fit_toas(maxiter=3)
+
+ assert ftr.resids.reduced_chi2 < 1.5
+
+ assert abs(ftr.model.PHOFF.value) <= 0.5
diff --git a/tests/test_photonphase.py b/tests/test_photonphase.py
index 7b747650d..048281eff 100644
--- a/tests/test_photonphase.py
+++ b/tests/test_photonphase.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
import os
import pytest
diff --git a/tests/test_piecewise.py b/tests/test_piecewise.py
index caceaa3f1..0809a6328 100644
--- a/tests/test_piecewise.py
+++ b/tests/test_piecewise.py
@@ -1,6 +1,5 @@
-#! /usr/bin/env python
import os
-import unittest
+import pytest
import astropy.units as u
import pytest
@@ -18,9 +17,9 @@
timfile = os.path.join(datadir, "piecewise.tim")
-class TestPiecewise(unittest.TestCase):
+class TestPiecewise:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.m = pint.models.get_model(parfile)
cls.m2 = pint.models.get_model(parfile2)
cls.t = pint.toa.get_TOAs(timfile)
diff --git a/tests/test_pintbary.py b/tests/test_pintbary.py
index 7def071a0..e40ef49cc 100644
--- a/tests/test_pintbary.py
+++ b/tests/test_pintbary.py
@@ -1,6 +1,5 @@
-#!/usr/bin/env python
import sys
-import unittest
+import pytest
import numpy as np
from io import StringIO
@@ -8,14 +7,14 @@
import pint.scripts.pintbary as pintbary
-class TestPintBary(unittest.TestCase):
+class TestPintBary:
def test_result(self):
saved_stdout, sys.stdout = sys.stdout, StringIO("_")
cmd = "56000.0 --ra 12h22m33.2s --dec 19d21m44.2s --obs gbt --ephem DE405"
pintbary.main(cmd.split())
v = sys.stdout.getvalue()
# Check that last value printed is the barycentered time
- self.assertTrue(np.isclose(float(v.split()[-1]), 56000.0061691189))
+ assert np.isclose(float(v.split()[-1]), 56000.0061691189)
sys.stdout = saved_stdout
@@ -41,6 +40,3 @@ def test_result(self):
# pintbary.output('some_text')
# self.assertEquals(plus_one.sys.stdout.getvalue(), 'some_text\n')
# plus_one.sys.stdout = saved_stdout
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_pintempo.py b/tests/test_pintempo.py
index 1f8963350..f83529cac 100644
--- a/tests/test_pintempo.py
+++ b/tests/test_pintempo.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
import os
import sys
@@ -6,14 +5,18 @@
from pint.scripts import pintempo
from pinttestdata import datadir
+import matplotlib
+import pytest
parfile = os.path.join(datadir, "NGC6440E.par")
timfile = os.path.join(datadir, "NGC6440E.tim")
-def test_result():
+@pytest.mark.parametrize("gls", ["", "--gls"])
+def test_pintempo(gls):
+ matplotlib.use("Agg")
saved_stdout, sys.stdout = sys.stdout, StringIO("_")
- cmd = "{0} {1}".format(parfile, timfile)
+ cmd = f"{parfile} {timfile} --plot {gls} --plotfile _test_pintempo.pdf --outfile _test_pintempo.out"
pintempo.main(cmd.split())
lines = sys.stdout.getvalue()
v = 999.0
diff --git a/tests/test_pintk.py b/tests/test_pintk.py
index d8d63cd28..a548e7a2a 100644
--- a/tests/test_pintk.py
+++ b/tests/test_pintk.py
@@ -1,6 +1,5 @@
-#!/usr/bin/env python
import os
-import unittest
+import pytest
import pytest
from io import StringIO
@@ -15,14 +14,10 @@
@pytest.mark.skipif(
"DISPLAY" not in os.environ, reason="Needs an X server, xvfb counts"
)
-class TestPintk(unittest.TestCase):
+class TestPintk:
def test_result(self):
saved_stdout, pintk.sys.stdout = pintk.sys.stdout, StringIO("_")
cmd = "--test {0} {1}".format(parfile, timfile)
pintk.main(cmd.split())
lines = pintk.sys.stdout.getvalue()
pintk.sys.stdout = saved_stdout
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_pldmnoise.py b/tests/test_pldmnoise.py
index cbf979728..9a27450c2 100644
--- a/tests/test_pldmnoise.py
+++ b/tests/test_pldmnoise.py
@@ -10,7 +10,9 @@
Add test that checks for same results between PINT and enterprise
"""
+
import astropy.units as u
+import contextlib
import io
import numpy as np
import pint.fitter as fitters
@@ -147,10 +149,8 @@ def test_PLRedNoise_recovery():
# refit model
f2 = fitters.DownhillGLSFitter(toas, model)
- try:
+ with contextlib.suppress(fitters.InvalidModelParameters, fitters.StepProblem):
f2.fit_toas()
- except (fitters.InvalidModelParameters, fitters.StepProblem):
- pass
f2.model.validate()
f2.model.validate_toas(toas)
r2 = Residuals(toas, f2.model)
diff --git a/tests/test_precision.py b/tests/test_precision.py
index 2421e5075..eddfd6670 100644
--- a/tests/test_precision.py
+++ b/tests/test_precision.py
@@ -44,38 +44,36 @@
time_eps = (np.finfo(float).eps * u.day).to(u.ns)
-leap_sec_mjds = set(
- [
- 41499,
- 41683,
- 42048,
- 42413,
- 42778,
- 43144,
- 43509,
- 43874,
- 44239,
- 44786,
- 45151,
- 45516,
- 46247,
- 47161,
- 47892,
- 48257,
- 48804,
- 49169,
- 49534,
- 50083,
- 50630,
- 51179,
- 53736,
- 54832,
- 56109,
- 57204,
- 57754,
- ]
-)
-leap_sec_days = set([d - 1 for d in leap_sec_mjds])
+leap_sec_mjds = {
+ 41499,
+ 41683,
+ 42048,
+ 42413,
+ 42778,
+ 43144,
+ 43509,
+ 43874,
+ 44239,
+ 44786,
+ 45151,
+ 45516,
+ 46247,
+ 47161,
+ 47892,
+ 48257,
+ 48804,
+ 49169,
+ 49534,
+ 50083,
+ 50630,
+ 51179,
+ 53736,
+ 54832,
+ 56109,
+ 57204,
+ 57754,
+}
+leap_sec_days = {d - 1 for d in leap_sec_mjds}
near_leap_sec_days = list(
sorted([d - 1 for d in leap_sec_days] + [d + 1 for d in leap_sec_days])
)
@@ -84,8 +82,7 @@
@composite
def possible_leap_sec_days(draw):
y = draw(integers(2017, 2050))
- s = draw(booleans())
- if s:
+ if s := draw(booleans()):
m = Time(datetime(y, 6, 30, 0, 0, 0), scale="tai").mjd
else:
m = Time(datetime(y, 12, 31, 0, 0, 0), scale="tai").mjd
@@ -327,9 +324,7 @@ def test_time_from_longdouble(scale, i_f):
@pytest.mark.parametrize("format", ["mjd", "pulsar_mjd"])
def test_time_from_longdouble_utc(format, i_f):
i, f = i_f
- assume(
- not (format == "pulsar_mjd" and i in leap_sec_days and (1 - f) * 86400 < 1e-9)
- )
+ assume(format != "pulsar_mjd" or i not in leap_sec_days or (1 - f) * 86400 >= 1e-9)
t = Time(val=i, val2=f, format=format, scale="utc")
ld = np.longdouble(i) + np.longdouble(f)
assert (
@@ -479,9 +474,7 @@ def test_posvel_respects_longdouble():
def test_time_from_mjd_string_accuracy_vs_longdouble(format, i_f):
i, f = i_f
mjd = np.longdouble(i) + np.longdouble(f)
- assume(
- not (format == "pulsar_mjd" and i in leap_sec_days and (1 - f) * 86400 < 1e-9)
- )
+ assume(format != "pulsar_mjd" or i not in leap_sec_days or (1 - f) * 86400 >= 1e-9)
s = str(mjd)
t = Time(val=i, val2=f, format=format, scale="utc")
assert (
@@ -507,9 +500,7 @@ def test_time_from_mjd_string_roundtrip_very_close(format):
@pytest.mark.parametrize("format", ["pulsar_mjd", "mjd"])
def test_time_from_mjd_string_roundtrip(format, i_f):
i, f = i_f
- assume(
- not (format == "pulsar_mjd" and i in leap_sec_days and (1 - f) * 86400 < 1e-9)
- )
+ assume(format != "pulsar_mjd" or i not in leap_sec_days or (1 - f) * 86400 >= 1e-9)
t = Time(val=i, val2=f, format=format, scale="utc")
assert (
abs(t - time_from_mjd_string(time_to_mjd_string(t), format=format)).to(u.ns)
@@ -683,15 +674,17 @@ def _test_erfa_conversion(leap, i_f):
else:
assert jd_change.to(u.ns) < 2 * u.ns
# assert jd_change.to(u.ns) < 1 * u.ns
- return
- i_o, f_o = day_frac(jd1 - erfa.DJM0, jd2)
+ # @abhisrkckl : There was a return statement above which made the
+ # code below unreachable. I have removed it and commented out the code below.
- mjd_change = abs((i_o - i_i) + (f_o - f_i)) * u.day
- if leap:
- assert mjd_change.to(u.s) < 1 * u.s
- else:
- assert mjd_change.to(u.ns) < 1 * u.ns
+ # i_o, f_o = day_frac(jd1 - erfa.DJM0, jd2)
+
+ # mjd_change = abs((i_o - i_i) + (f_o - f_i)) * u.day
+ # if leap:
+ # assert mjd_change.to(u.s) < 1 * u.s
+ # else:
+ # assert mjd_change.to(u.ns) < 1 * u.ns
# Try to nail down ERFA behaviour
@@ -760,7 +753,7 @@ def test_mjd_jd_round_trip(i_f):
@given(reasonable_mjd())
def test_mjd_jd_pulsar_round_trip(i_f):
i, f = i_f
- assume(not (i in leap_sec_days and (1 - f) * 86400 < 1e-9))
+ assume(i not in leap_sec_days or (1 - f) * 86400 >= 1e-9)
with decimal.localcontext(decimal.Context(prec=40)):
jds = mjds_to_jds_pulsar(*i_f)
assert_closer_than_ns(jds_to_mjds_pulsar(*jds), i_f, 1)
diff --git a/tests/test_prefixparaminheritance.py b/tests/test_prefix_param_inheritance.py
similarity index 100%
rename from tests/test_prefixparaminheritance.py
rename to tests/test_prefix_param_inheritance.py
diff --git a/tests/test_priors.py b/tests/test_priors.py
index 99d349a9d..0f6b8a64d 100644
--- a/tests/test_priors.py
+++ b/tests/test_priors.py
@@ -1,6 +1,5 @@
-#!/usr/bin/env python
import os
-import unittest
+import pytest
import numpy as np
from scipy.stats import norm
@@ -17,9 +16,9 @@
from pinttestdata import datadir
-class TestPriors(unittest.TestCase):
+class TestPriors:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.m = pint.models.get_model("B1855+09_NANOGrav_dfg+12_modified.par")
@@ -85,7 +84,3 @@ def test_gaussian_bounded(self):
)
# Test that integral is 1.0, not safe since using _rv private var
assert np.isclose(self.m.M2.prior._rv.cdf(0.6), 1.0)
-
-
-# if __name__ == '__main__':
-# unittest.main()
diff --git a/tests/test_process_parfile.py b/tests/test_process_parfile.py
index 732a1319f..8aa570250 100644
--- a/tests/test_process_parfile.py
+++ b/tests/test_process_parfile.py
@@ -1,4 +1,3 @@
-#! /usr/bin/env python
"""Tests for par file parsing, pintify parfile and value fill up
"""
@@ -140,14 +139,14 @@ def test_pintify_parfile_alises():
def test_pintify_parfile_nonrepeat_with_alise_repeating():
- aliases_par = base_par + "LAMBDA 2 1 3"
+ aliases_par = f"{base_par}LAMBDA 2 1 3"
m = ModelBuilder()
with pytest.raises(TimingModelError):
m._pintify_parfile(StringIO(aliases_par))
def test_pintify_parfile_unrecognize():
- wrong_par = base_par + "UNKNOWN 2 1 1"
+ wrong_par = f"{base_par}UNKNOWN 2 1 1"
m = ModelBuilder()
pint_dict, original_name, unknow = m._pintify_parfile(StringIO(wrong_par))
assert unknow == {"UNKNOWN": ["2 1 1"]}
diff --git a/tests/test_pulsar_mjd.py b/tests/test_pulsar_mjd.py
index 8c1185e93..4fe96cecd 100644
--- a/tests/test_pulsar_mjd.py
+++ b/tests/test_pulsar_mjd.py
@@ -1,6 +1,5 @@
-#!/usr/bin/env python
import os
-import unittest
+import pytest
import numpy as np
from pint.pulsar_mjd import Time
@@ -8,9 +7,9 @@
from pinttestdata import datadir
-class TestPsrMjd(unittest.TestCase):
+class TestPsrMjd:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.leap_second_days = ["2016-12-31T12:00:00", "2015-06-30T12:00:00"]
cls.normal_days = ["2016-12-11T12:00:00", "2015-06-02T12:00:00"]
diff --git a/tests/test_pulsar_position.py b/tests/test_pulsar_position.py
index c4eb1b6ca..929daa204 100644
--- a/tests/test_pulsar_position.py
+++ b/tests/test_pulsar_position.py
@@ -1,7 +1,7 @@
"""Various tests to assess the performance of the PINT position.
"""
import os
-import unittest
+import pytest
import numpy as np
@@ -9,9 +9,9 @@
from pinttestdata import datadir
-class TestPulsarPosition(unittest.TestCase):
+class TestPulsarPosition:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
# This uses ELONG and ELAT
cls.m1 = mb.get_model("B1855+09_NANOGrav_9yv1.gls.par")
@@ -33,7 +33,7 @@ def test_ssb_2_psr(self):
p1 = self.m1.ssb_to_psb_xyz_ICRS(epoch=self.t)
p2 = self.m2.ssb_to_psb_xyz_ICRS(epoch=self.t)
- self.assertTrue(np.max(np.abs(p1 - p2)) < 1e-6)
+ assert np.max(np.abs(p1 - p2)) < 1e-6
# Switch on PM
self.m1.PMELONG.value = PMELONG_v
@@ -44,7 +44,7 @@ def test_ssb_2_psr(self):
p1 = self.m1.ssb_to_psb_xyz_ICRS(epoch=self.t)
p2 = self.m2.ssb_to_psb_xyz_ICRS(epoch=self.t)
- self.assertTrue(np.max(np.abs(p1 - p2)) < 1e-7)
+ assert np.max(np.abs(p1 - p2)) < 1e-7
def test_parse_line(self):
self.m1.ELONG.from_parfile_line(
@@ -67,7 +67,7 @@ def test_parse_line(self):
self.m1.PMELONG.from_parfile_line("PMELONG -3.2701 1 0.0141")
self.m1.PMELAT.from_parfile_line("PMELAT -5.0982 1 0.0291")
- self.assertTrue(np.isclose(self.m1.ELONG.value, ELONG_v))
- self.assertTrue(np.isclose(self.m1.ELAT.value, ELAT_v))
- self.assertTrue(np.isclose(self.m1.PMELONG.value, PMELONG_v))
- self.assertTrue(np.isclose(self.m1.PMELAT.value, PMELAT_v))
+ assert np.isclose(self.m1.ELONG.value, ELONG_v)
+ assert np.isclose(self.m1.ELAT.value, ELAT_v)
+ assert np.isclose(self.m1.PMELONG.value, PMELONG_v)
+ assert np.isclose(self.m1.PMELAT.value, PMELAT_v)
diff --git a/tests/test_pulse_number.py b/tests/test_pulse_number.py
index acb0e8b1d..6a52b1e83 100644
--- a/tests/test_pulse_number.py
+++ b/tests/test_pulse_number.py
@@ -1,4 +1,4 @@
-#!/usr/bin/python
+import contextlib
from io import StringIO
import os
import pytest
@@ -148,13 +148,9 @@ def test_fitter_respects_pulse_numbers(fitter, model, fake_toas):
fitter(t, m_2, track_mode="capybara")
f_1 = fitter(t, m_2, track_mode="nearest")
- try:
+ with contextlib.suppress(ValueError):
f_1.fit_toas()
assert abs(f_1.model.F0.quantity - model.F0.quantity) > 0.1 * delta_f
- except ValueError:
- # convergence fails for Downhill fitters
- pass
-
f_2 = fitter(t, m_2, track_mode="use_pulse_numbers")
f_2.fit_toas()
assert abs(f_2.model.F0.quantity - model.F0.quantity) < 0.01 * delta_f
diff --git a/tests/test_random_models.py b/tests/test_random_models.py
index f1fcfd364..2cc716bb0 100644
--- a/tests/test_random_models.py
+++ b/tests/test_random_models.py
@@ -72,3 +72,18 @@ def test_random_models_wb(fitter):
# this is a bit stochastic, but I see typically < 1e-4 cycles for this
# for the uncertainty at 59000
assert np.all(dphase.std(axis=0) < 1e-4)
+
+
+@pytest.mark.parametrize("clock", ["UTC(NIST)", "TT(BIPM2021)", "TT(BIPM2015)"])
+def test_random_model_clock(clock):
+ m, t = get_model_and_toas(
+ os.path.join(datadir, "NGC6440E.par"), os.path.join(datadir, "NGC6440E.tim")
+ )
+ m.CLOCK.value = clock
+ if "BIPM" in clock:
+ assert simulation.get_fake_toa_clock_versions(m)[
+ "bipm_version"
+ ] == m.CLOCK.value.replace("TT(", "").replace(")", "")
+ assert simulation.get_fake_toa_clock_versions(m)["include_bipm"]
+ else:
+ assert ~simulation.get_fake_toa_clock_versions(m)["include_bipm"]
diff --git a/tests/test_residuals.py b/tests/test_residuals.py
index 021b17ad4..d4ed15d0d 100644
--- a/tests/test_residuals.py
+++ b/tests/test_residuals.py
@@ -3,6 +3,7 @@
import os
from io import StringIO
+import copy
import astropy.units as u
import numpy as np
@@ -14,6 +15,7 @@
from pint.fitter import GLSFitter, WidebandTOAFitter, WLSFitter
from pint.models import get_model
from pint.residuals import CombinedResiduals, Residuals, WidebandTOAResiduals
+import pint.residuals
from pint.toa import get_TOAs
from pint.simulation import make_fake_toas_uniform
@@ -47,7 +49,7 @@ def wideband_fake():
class TestResidualBuilding:
- def setup(self):
+ def setup_method(self):
self.model = get_model("J1614-2230_NANOGrav_12yv3.wb.gls.par")
# self.toa = make_fake_toas(57000,59000,20,model=self.model)
self.toa = get_TOAs("J1614-2230_NANOGrav_12yv3.wb.tim")
@@ -353,3 +355,85 @@ def test_wideband_chi2_updating(wideband_fake):
assert 1e-3 > abs(WidebandTOAResiduals(toas, f2.model).chi2 - ftc2) > 1e-5
ftc2 = f2.fit_toas(maxiter=10)
assert_allclose(WidebandTOAResiduals(toas, f2.model).chi2, ftc2)
+
+
+@pytest.mark.parametrize(
+ "d",
+ [
+ 1 * u.s,
+ np.ones(20) * u.s,
+ TimeDelta(np.ones(20) * u.s),
+ ],
+)
+def test_toa_adjust(d):
+ model = get_model(
+ StringIO(
+ """
+ PSRJ J1234+5678
+ ELAT 0
+ ELONG 0
+ DM 10
+ F0 1
+ PEPOCH 58000
+ EFAC mjd 57000 58000 2
+ """
+ )
+ )
+ toas = make_fake_toas_uniform(57000, 59000, 20, model=model, error=1 * u.us)
+ toas2 = copy.deepcopy(toas)
+ toas2.adjust_TOAs(d)
+ if isinstance(d, TimeDelta):
+ dcomp = d.sec * u.s
+ else:
+ dcomp = d
+ r = Residuals(toas, model, subtract_mean=False)
+ r2 = Residuals(toas2, model, subtract_mean=False)
+ assert np.allclose(r2.calc_time_resids() - r.calc_time_resids(), dcomp, atol=1e-3)
+
+
+def test_resid_mean():
+ model = get_model(
+ StringIO(
+ """
+ PSRJ J1234+5678
+ ELAT 0
+ ELONG 0
+ DM 10
+ F0 1
+ PEPOCH 58000
+ EFAC mjd 57000 58000 2
+ """
+ )
+ )
+ toas = make_fake_toas_uniform(57000, 59000, 20, model=model, error=1 * u.us)
+ r = Residuals(toas, model, subtract_mean=False)
+ r2 = Residuals(toas, model)
+ full = r.calc_time_resids()
+ w = 1.0 / (r.get_data_error().value ** 2)
+ mean, _ = pint.residuals.weighted_mean(full, w)
+ assert mean == r.calc_time_mean()
+ assert r.calc_time_mean() == r2.calc_time_mean()
+
+
+def test_resid_mean_phase():
+ model = get_model(
+ StringIO(
+ """
+ PSRJ J1234+5678
+ ELAT 0
+ ELONG 0
+ DM 10
+ F0 1
+ PEPOCH 58000
+ EFAC mjd 57000 58000 2
+ """
+ )
+ )
+ toas = make_fake_toas_uniform(57000, 59000, 20, model=model, error=1 * u.us)
+ r = Residuals(toas, model, subtract_mean=False)
+ r2 = Residuals(toas, model)
+ full = r.calc_phase_resids()
+ w = 1.0 / (r.get_data_error().value ** 2)
+ mean, _ = pint.residuals.weighted_mean(full, w)
+ assert mean == r.calc_phase_mean()
+ assert r.calc_phase_mean() == r2.calc_phase_mean()
diff --git a/tests/test_solar_system_body.py b/tests/test_solar_system_body.py
index e12f9182d..3367c45b3 100644
--- a/tests/test_solar_system_body.py
+++ b/tests/test_solar_system_body.py
@@ -1,6 +1,5 @@
-#!/usr/bin/env python
import os
-import unittest
+import pytest
import astropy.time as time
import numpy as np
@@ -10,16 +9,10 @@
from pint.solar_system_ephemerides import objPosVel, objPosVel_wrt_SSB
from pinttestdata import datadir
-# Hack to support FileNotFoundError in Python 2
-try:
- FileNotFoundError
-except NameError:
- FileNotFoundError = IOError
-
-class TestSolarSystemDynamic(unittest.TestCase):
+class TestSolarSystemDynamic:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
MJDREF = 2400000.5
J2000_JD = 2451545.0
diff --git a/tests/test_solar_wind.py b/tests/test_solar_wind.py
index 6f071ada0..8aa6cf9ee 100644
--- a/tests/test_solar_wind.py
+++ b/tests/test_solar_wind.py
@@ -262,8 +262,8 @@ def test_swfit_swx():
model1.SWM.value = 0
# SWX model
model3 = copy.deepcopy(model0)
- model3.add_component(SolarWindDispersionX())
model3.remove_component("SolarWindDispersion")
+ model3.add_component(SolarWindDispersionX())
model3.SWXR1_0001.value = t.get_mjds().min().value
model3.SWXR2_0001.value = t.get_mjds().max().value
model3.SWXP_0001.value = 2
diff --git a/tests/test_tcb2tdb.py b/tests/test_tcb2tdb.py
new file mode 100644
index 000000000..34e7373a1
--- /dev/null
+++ b/tests/test_tcb2tdb.py
@@ -0,0 +1,83 @@
+"""Tests for `pint.models.tcb_conversion` and the `tcb2tdb` script."""
+
+import os
+from copy import deepcopy
+from io import StringIO
+
+import numpy as np
+import pytest
+
+from pint.models.model_builder import ModelBuilder
+from pint.models.tcb_conversion import convert_tcb_tdb
+from pint.scripts import tcb2tdb
+
+simplepar = """
+PSR PSRTEST
+RAJ 17:48:52.75 1
+DECJ -20:21:29.0 1
+F0 61.485476554 1
+F1 -1.181D-15 1
+PEPOCH 53750.000000
+POSEPOCH 53750.000000
+DM 223.9 1
+SOLARN0 0.00
+BINARY BT
+T0 53750
+A1 100.0 1 0.1
+ECC 1.0
+OM 0.0
+PB 10.0
+EPHEM DE436
+CLK TT(BIPM2017)
+UNITS TCB
+TIMEEPH FB90
+T2CMETHOD TEMPO
+CORRECT_TROPOSPHERE N
+PLANET_SHAPIRO Y
+DILATEFREQ N
+"""
+
+
+@pytest.mark.parametrize("backwards", [True, False])
+def test_convert_units(backwards):
+ with pytest.raises(ValueError):
+ m = ModelBuilder()(StringIO(simplepar))
+
+ m = ModelBuilder()(StringIO(simplepar), allow_tcb="raw")
+ f0_tcb = m.F0.value
+ pb_tcb = m.PB.value
+ convert_tcb_tdb(m, backwards=backwards)
+ assert m.UNITS.value == ("TCB" if backwards else "TDB")
+ assert np.isclose(m.F0.value / f0_tcb, pb_tcb / m.PB.value)
+
+
+def test_convert_units_roundtrip():
+ m = ModelBuilder()(StringIO(simplepar), allow_tcb="raw")
+ m_ = deepcopy(m)
+ convert_tcb_tdb(m, backwards=False)
+ convert_tcb_tdb(m, backwards=True)
+
+ for par in m.params:
+ p = getattr(m, par)
+ p_ = getattr(m_, par)
+ if p.value is None:
+ assert p_.value is None
+ elif isinstance(p.value, str):
+ assert getattr(m, par).value == getattr(m_, par).value
+ else:
+ assert np.isclose(getattr(m, par).value, getattr(m_, par).value)
+
+
+def test_tcb2tdb(tmp_path):
+ tmppar1 = tmp_path / "tmp1.par"
+ tmppar2 = tmp_path / "tmp2.par"
+ with open(tmppar1, "w") as f:
+ f.write(simplepar)
+
+ cmd = f"{tmppar1} {tmppar2}"
+ tcb2tdb.main(cmd.split())
+
+ assert os.path.isfile(tmppar2)
+
+ m2 = ModelBuilder()(tmppar2)
+ assert m2.UNITS.value == "TDB"
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 18bd0a1ad..21b293bc8 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -631,38 +631,41 @@ def sn(x, p, index=0):
lct = lctemplate.LCTemplate([p1, p2], [0.4, 0.6])
assert lct.check_gradient(quiet=True, seed=0)
assert lct.check_derivative()
- return
- p1.free[1] = False
- assert p1.gradient(ph[:10], free=True).shape[0] == 1
-
- p1 = lceprimitives.LCEVonMises(p=[0.05, 0.1], slope=[0, 0.1])
- p2 = lceprimitives.LCEVonMises(p=[0.05, 0.4], slope=[0.05, 0.05])
- assert np.all(p2.slope == 0.05)
- assert np.allclose(p1(ph, log10_ens=en), vm(ph, [0.05, 0.1 + (en - 3) * 0.1]))
- lct = lctemplate.LCTemplate([p1, p2], [0.4, 0.6])
- assert lct.is_energy_dependent()
- assert p1.gradient(ph[:10], free=True).shape[0] == 2
- assert p1.gradient(ph[:10], free=False).shape[0] == 4
- assert lct.check_gradient(quiet=True, seed=0)
- assert lct.check_derivative()
-
- # check integral
- ens = np.linspace(2, 4, 21)
- integrals = p1.integrate(0, 0.5, log10_ens=ens)
- quads = [
- quad(lambda x: p1(x, log10_ens=ens[i]), 0, 0.5)[0] for i in range(len(ens))
- ]
- assert np.allclose(quads, integrals)
-
- # check overflow
- p1 = lceprimitives.LCEVonMises()
- p1.p[0] = p1.bounds[0][0]
- assert not np.isnan(p1(p1.p[-1]))
-
- # check fast Bessel function
- q = np.logspace(-1, np.log10(700), 201)
- I0 = lcprimitives.FastBessel(order=0)
- I1 = lcprimitives.FastBessel(order=1)
- assert np.all(np.abs(I0(q) / i0(q) - 1) < 1e-2)
- assert np.all(np.abs(I1(q) / i1(q) - 1) < 1e-2)
+ # @abhisrkckl : There was a return statement above which made the code below unreachable.
+ # I have removed the return statement and commented out the below code.
+ # The following tests don't actually work. I am not sure why.
+
+ # p1.free[1] = False
+ # assert p1.gradient(ph[:10], free=True).shape[0] == 1
+
+ # p1 = lceprimitives.LCEVonMises(p=[0.05, 0.1], slope=[0, 0.1])
+ # p2 = lceprimitives.LCEVonMises(p=[0.05, 0.4], slope=[0.05, 0.05])
+ # assert np.all(p2.slope == 0.05)
+ # assert np.allclose(p1(ph, log10_ens=en), vm(ph, [0.05, 0.1 + (en - 3) * 0.1]))
+ # lct = lctemplate.LCTemplate([p1, p2], [0.4, 0.6])
+ # assert lct.is_energy_dependent()
+ # assert p1.gradient(ph[:10], free=True).shape[0] == 2
+ # assert p1.gradient(ph[:10], free=False).shape[0] == 4
+ # assert lct.check_gradient(quiet=True, seed=0)
+ # assert lct.check_derivative()
+
+ # # check integral
+ # ens = np.linspace(2, 4, 21)
+ # integrals = p1.integrate(0, 0.5, log10_ens=ens)
+ # quads = [
+ # quad(lambda x: p1(x, log10_ens=ens[i]), 0, 0.5)[0] for i in range(len(ens))
+ # ]
+ # assert np.allclose(quads, integrals)
+
+ # # check overflow
+ # p1 = lceprimitives.LCEVonMises()
+ # p1.p[0] = p1.bounds[0][0]
+ # assert not np.isnan(p1(p1.p[-1]))
+
+ # # check fast Bessel function
+ # q = np.logspace(-1, np.log10(700), 201)
+ # I0 = lcprimitives.FastBessel(order=0)
+ # I1 = lcprimitives.FastBessel(order=1)
+ # assert np.all(np.abs(I0(q) / i0(q) - 1) < 1e-2)
+ # assert np.all(np.abs(I1(q) / i1(q) - 1) < 1e-2)
diff --git a/tests/test_tempox_compatibility.py b/tests/test_tempox_compatibility.py
index d76a9bab5..ca3b4b2c1 100644
--- a/tests/test_tempox_compatibility.py
+++ b/tests/test_tempox_compatibility.py
@@ -32,7 +32,7 @@
def test_noise_parameter_aliases(name, want):
parfile = "\n".join([par_basic, f"{name} -f bogus 1.234"])
m = get_model(StringIO(parfile))
- assert getattr(m, want + "1").value == 1.234
+ assert getattr(m, f"{want}1").value == 1.234
assert re.search(
f"^{want}" + r"\s+-f\s+bogus\s+1.23\d+\s*$", m.as_parfile(), re.MULTILINE
)
diff --git a/tests/test_tim_writing.py b/tests/test_tim_writing.py
index dcbea292e..30c8f950c 100644
--- a/tests/test_tim_writing.py
+++ b/tests/test_tim_writing.py
@@ -63,7 +63,7 @@ def test_basic():
],
)
def test_time(c):
- f = StringIO(basic_tim_header + "\n{}\n".format(c) + basic_tim)
+ f = StringIO(f"{basic_tim_header}\n{c}\n{basic_tim}")
do_roundtrip(get_TOAs(f))
diff --git a/tests/test_times.py b/tests/test_times.py
index 11f1c0ae9..4d282a756 100644
--- a/tests/test_times.py
+++ b/tests/test_times.py
@@ -24,7 +24,7 @@ def test_times_against_tempo2():
ls = u.def_unit("ls", const.c * 1.0 * u.s)
log.info("Reading TOAs into PINT")
- ts = toa.get_TOAs(datadir + "/testtimes.tim", include_bipm=False, usepickle=False)
+ ts = toa.get_TOAs(f"{datadir}/testtimes.tim", include_bipm=False, usepickle=False)
if log.level < 25:
ts.print_summary()
ts.table.sort("index")
@@ -40,7 +40,7 @@ def test_times_against_tempo2():
# assert(len(goodlines)==len(ts.table))
# t2result = numpy.genfromtxt('datafile/testtimes.par' + '.tempo2_test', names=True, comments = '#')
- f = open(datadir + "/testtimes.par" + ".tempo2_test")
+ f = open(f"{datadir}/testtimes.par.tempo2_test")
lines = f.readlines()
goodlines = lines[1:]
# Get the output lines from the TOAs
diff --git a/tests/test_timing_model.py b/tests/test_timing_model.py
index e68671539..82541ace2 100644
--- a/tests/test_timing_model.py
+++ b/tests/test_timing_model.py
@@ -46,7 +46,7 @@ def timfile_nojumps():
class TestModelBuilding:
- def setup(self):
+ def setup_method(self):
self.parfile = os.path.join(datadir, "J0437-4715.par")
def test_from_par(self):
diff --git a/tests/test_toa.py b/tests/test_toa.py
index 2c27c9063..20303f174 100644
--- a/tests/test_toa.py
+++ b/tests/test_toa.py
@@ -1,13 +1,12 @@
+import pytest
import io
import os
import re
-import unittest
-import numpy as np
import pytest
+import numpy as np
import astropy.units as u
-import astropy.constants as c
-from astropy.time import Time, TimeDelta
+from astropy.time import Time
from datetime import datetime
from pinttestdata import datadir
@@ -18,38 +17,38 @@
from pint.simulation import make_fake_toas_uniform
-class TestTOA(unittest.TestCase):
+class TestTOA:
"""Test of TOA class"""
- def setUp(self):
+ def setup_method(self):
self.MJD = 57000
def test_units(self):
- with self.assertRaises(u.UnitConversionError):
+ with pytest.raises(u.UnitConversionError):
t = TOA(self.MJD * u.m)
- with self.assertRaises(u.UnitConversionError):
+ with pytest.raises(u.UnitConversionError):
t = TOA((self.MJD * u.m, 0))
t = TOA((self.MJD * u.day).to(u.s))
- with self.assertRaises(u.UnitConversionError):
+ with pytest.raises(u.UnitConversionError):
t = TOA((self.MJD * u.day, 0))
t = TOA((self.MJD * u.day, 0 * u.day))
- with self.assertRaises(u.UnitConversionError):
+ with pytest.raises(u.UnitConversionError):
t = TOA(self.MJD, error=1 * u.m)
t = TOA(self.MJD, freq=100 * u.kHz)
- with self.assertRaises(u.UnitConversionError):
+ with pytest.raises(u.UnitConversionError):
t = TOA(self.MJD, freq=100 * u.s)
def test_precision_mjd(self):
t = TOA(self.MJD)
- self.assertEqual(t.mjd.precision, 9)
+ assert t.mjd.precision == 9
def test_precision_time(self):
t = TOA(Time("2008-08-19", format="iso", precision=1))
- self.assertEqual(t.mjd.precision, 9)
+ assert t.mjd.precision == 9
def test_typo(self):
TOA(self.MJD, errror="1")
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
TOA(self.MJD, errror=1, flags={})
def test_toa_object(self):
@@ -59,21 +58,21 @@ def test_toa_object(self):
obs = "ao"
# scale should be None when MJD is a Time
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
toa = TOA(MJD=toatime, error=toaerr, freq=freq, obs=obs, scale="utc")
# flags should be stored without their leading -
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
toa = TOA(
MJD=toatime, error=toaerr, freq=freq, obs=obs, flags={"-foo": "foo1"}
)
# Invalid flag
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
toa = TOA(
MJD=toatime, error=toaerr, freq=freq, obs=obs, flags={"$": "foo1"}
)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
toa = TOA(MJD=toatime, error=toaerr, freq=freq, obs=obs, flags={"foo": 1})
toa = TOA(MJD=toatime, error=toaerr, freq=freq, obs=obs, foo="foo1")
@@ -82,23 +81,23 @@ def test_toa_object(self):
assert len(toa.flags) > 0
# Missing name
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
toa.as_line()
toa = TOA(MJD=toatime, error=toaerr, freq=freq, obs=obs, foo="foo1", name="bla")
assert "bla" in toa.as_line()
-class TestTOAs(unittest.TestCase):
+class TestTOAs:
"""Test of TOAs class"""
- def setUp(self):
+ def setup_method(self):
self.freq = 1440.012345678 * u.MHz
self.obs = "gbt"
self.MJD = 57000
self.error = 3.0
- def test_make_TOAs(self):
+ def test_make_toas(self):
t = TOA(self.MJD, freq=self.freq, obs=self.obs, error=self.error)
t_list = [t, t]
assert t_list[0].mjd.precision == 9
diff --git a/tests/test_toa_create.py b/tests/test_toa_create.py
index a54cdd712..3c9c63c5d 100644
--- a/tests/test_toa_create.py
+++ b/tests/test_toa_create.py
@@ -1,5 +1,5 @@
import numpy as np
-from astropy import units as u, constants as c
+from astropy import units as u
from pint import pulsar_mjd
from pint import toa
from pint.models import get_model
@@ -53,6 +53,47 @@ def test_toas_compare(t, errors, freqs):
flags = [{"type": "good"}, {"type": "bad"}]
toas = toa.get_TOAs_array(t, obs, errors=errors, freqs=freqs, flags=flags, **kwargs)
+ combined_flags = copy.deepcopy(flags)
+ for flag in combined_flags:
+ for k, v in kwargs.items():
+ flag[k] = v
+ errors = np.atleast_1d(errors)
+ freqs = np.atleast_1d(freqs)
+ if len(errors) < len(t):
+ errors = np.repeat(errors, len(t))
+ if len(freqs) < len(t):
+ freqs = np.repeat(freqs, len(t))
+ toalist = (
+ [
+ toa.TOA((tt0, tt1), obs=obs, error=e, freq=fr, flags=f)
+ for tt0, tt1, e, f, fr in zip(t[0], t[1], errors, combined_flags, freqs)
+ ]
+ if isinstance(t, tuple)
+ else [
+ toa.TOA(tt, obs=obs, error=e, freq=fr, flags=f)
+ for tt, e, f, fr in zip(t, errors, combined_flags, freqs)
+ ]
+ )
+ toas_fromlist = toa.get_TOAs_list(toalist)
+ assert np.all(toas.table == toas_fromlist.table)
+
+
+@pytest.mark.parametrize(
+ "t",
+ [
+ pulsar_mjd.Time(np.array([55000, 56000]), scale="utc", format="pulsar_mjd"),
+ np.array([55000, 56000]),
+ (np.array([55000, 56000]), np.array([0, 0])),
+ ],
+)
+@pytest.mark.parametrize("errors", [1, 1.0 * u.us, np.array([1.0, 2.3]) * u.us])
+@pytest.mark.parametrize("freqs", [1400, 400 * u.MHz, np.array([1400, 1500]) * u.MHz])
+def test_toas_tolist(t, errors, freqs):
+ obs = "gbt"
+ kwargs = {"name": "data"}
+ flags = [{"type": "good"}, {"type": "bad"}]
+ toas = toa.get_TOAs_array(t, obs, errors=errors, freqs=freqs, flags=flags, **kwargs)
+
combined_flags = copy.deepcopy(flags)
for flag in combined_flags:
for k, v in kwargs.items():
@@ -73,8 +114,13 @@ def test_toas_compare(t, errors, freqs):
toa.TOA((tt0, tt1), obs=obs, error=e, freq=fr, flags=f)
for tt0, tt1, e, f, fr in zip(t[0], t[1], errors, combined_flags, freqs)
]
- toas_fromlist = toa.get_TOAs_list(toalist)
- assert np.all(toas.table == toas_fromlist.table)
+ toas_tolist = toas.to_TOA_list()
+ # for i in range(len(toalist)):
+ # assert toalist[i] == toas_tolist[i], f"{toalist[i]} != {toas_tolist[i]}"
+ # depending on precision they should be equal, but if they aren't then just check the MJDs
+ assert np.all([t1 == t2 for (t1, t2) in zip(toalist, toas_tolist)]) or np.allclose(
+ [x.mjd.mjd for x in toalist], [x.mjd.mjd for x in toas_tolist]
+ )
def test_kwargs_as_flags():
diff --git a/tests/test_toa_flag.py b/tests/test_toa_flag.py
index 9126f94c3..563578cbe 100644
--- a/tests/test_toa_flag.py
+++ b/tests/test_toa_flag.py
@@ -2,7 +2,7 @@
import os
-import unittest
+import pytest
import io
import numpy as np
import pytest
@@ -11,11 +11,11 @@
from pinttestdata import datadir
-class TestToaFlag(unittest.TestCase):
+class TestToaFlag:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
cls.tim = "B1855+09_NANOGrav_dfg+12.tim"
cls.toas = toa.get_TOAs(
os.path.join(datadir, cls.tim),
@@ -100,7 +100,7 @@ def test_flag_set_partialarray():
if i < 2:
assert float(t["test", i]) == 1
else:
- assert not "test" in t["flags"][i]
+ assert "test" not in t["flags"][i]
def test_flag_delete():
@@ -108,7 +108,7 @@ def test_flag_delete():
t["test"] = "1"
t["test"] = ""
for i in range(t.ntoas):
- assert not "test" in t["flags"][i]
+ assert "test" not in t["flags"][i]
def test_flag_partialdelete():
@@ -117,6 +117,6 @@ def test_flag_partialdelete():
t[:2, "test"] = ""
for i in range(t.ntoas):
if i < 2:
- assert not "test" in t["flags"][i]
+ assert "test" not in t["flags"][i]
else:
assert float(t["test", i]) == 1
diff --git a/tests/test_toa_pickle.py b/tests/test_toa_pickle.py
index a12a740ca..d7dc1cfee 100644
--- a/tests/test_toa_pickle.py
+++ b/tests/test_toa_pickle.py
@@ -1,8 +1,8 @@
-#!/usr/bin/env python
+import contextlib
import os
import shutil
import time
-import unittest
+import pytest
import copy
import pytest
@@ -19,16 +19,14 @@ def temp_tim(tmpdir):
return tt, tp
-class TestTOAReader(unittest.TestCase):
- def setUp(self):
+class TestTOAReader:
+ def setup_method(self):
os.chdir(datadir)
# First, read the TOAs from the tim file.
# This should also create the pickle file.
- try:
+ with contextlib.suppress(OSError):
os.remove("test1.tim.pickle.gz")
os.remove("test1.tim.pickle")
- except OSError:
- pass
tt = toa.get_TOAs("test1.tim", usepickle=False, include_bipm=False)
self.numtoas = tt.ntoas
del tt
@@ -87,10 +85,8 @@ def test_pickle_changed_planets(temp_tim):
)
def test_pickle_invalidated_settings(temp_tim, k, v, wv):
tt, tp = temp_tim
- d = {}
- d[k] = v
- wd = {}
- wd[k] = wv
+ d = {k: v}
+ wd = {k: wv}
toa.get_TOAs(tt, usepickle=True, **d)
assert toa.get_TOAs(tt, usepickle=True, **wd).clock_corr_info[k] == wv
diff --git a/tests/test_toa_reader.py b/tests/test_toa_reader.py
index 4d281187b..77c7e04f2 100644
--- a/tests/test_toa_reader.py
+++ b/tests/test_toa_reader.py
@@ -1,6 +1,6 @@
import os
import shutil
-import unittest
+import pytest
from io import StringIO
from pathlib import Path
@@ -61,8 +61,8 @@ def check_indices_contiguous(toas):
assert ix[0] == 0
-class TestTOAReader(unittest.TestCase):
- def setUp(self):
+class TestTOAReader:
+ def setup_method(self):
self.x = toa.TOAs(datadir / "test1.tim")
self.x.apply_clock_corrections()
self.x.compute_TDBs()
@@ -185,7 +185,7 @@ def test_toa_merge():
datadir / "parkes.toa",
]
toas = [toa.get_TOAs(ff) for ff in filenames]
- ntoas = sum([tt.ntoas for tt in toas])
+ ntoas = sum(tt.ntoas for tt in toas)
nt = toa.merge_TOAs(toas)
assert len(nt.observatories) == 3
assert nt.table.meta["filename"] == filenames
@@ -201,7 +201,7 @@ def test_toa_merge_again():
datadir / "parkes.toa",
]
toas = [toa.get_TOAs(ff) for ff in filenames]
- ntoas = sum([tt.ntoas for tt in toas])
+ ntoas = sum(tt.ntoas for tt in toas)
nt = toa.merge_TOAs(toas)
# The following tests merging with and already merged TOAs
other = toa.get_TOAs(datadir / "test1.tim")
@@ -218,7 +218,7 @@ def test_toa_merge_again_2():
datadir / "parkes.toa",
]
toas = [toa.get_TOAs(ff) for ff in filenames]
- ntoas = sum([tt.ntoas for tt in toas])
+ ntoas = sum(tt.ntoas for tt in toas)
other = toa.get_TOAs(datadir / "test1.tim")
# check consecutive merging
nt = toa.merge_TOAs(toas[:2])
diff --git a/tests/test_toa_selection.py b/tests/test_toa_selection.py
index f96424253..47a05eadc 100644
--- a/tests/test_toa_selection.py
+++ b/tests/test_toa_selection.py
@@ -1,7 +1,7 @@
import copy
import logging
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -15,9 +15,9 @@
from pinttestdata import datadir
-class TestTOAselection(unittest.TestCase):
+class TestTOAselection:
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.parf = "B1855+09_NANOGrav_9yv1.gls.par"
cls.timf = "B1855+09_NANOGrav_9yv1.tim"
@@ -42,7 +42,7 @@ def add_dmx_section(self, toas_table):
toas_table["mjd_float"] >= r1.mjd, toas_table["mjd_float"] <= r2.mjd
)
toas_table["DMX_section"][msk] = epoch_ind
- epoch_ind = epoch_ind + 1
+ epoch_ind += 1
return toas_table
def get_dmx_old(self, toas):
@@ -76,7 +76,7 @@ def test_boolean_selection(self):
self.toas.unselect()
assert self.toas.ntoas == 4005
- def test_DMX_selection(self):
+ def test_dmx_selection(self):
dmx_old = self.get_dmx_old(self.toas).value
# New way in the code.
dmx_new = self.model.dmx_dm(self.toas).value
@@ -131,7 +131,3 @@ def test_change_condition(self):
assert len(indx0005_2) == 0
assert len(run1) == len(run2)
assert np.allclose(run1, run2)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_troposphere_model.py b/tests/test_troposphere_model.py
index 6d32f2715..9aa0cfeb6 100644
--- a/tests/test_troposphere_model.py
+++ b/tests/test_troposphere_model.py
@@ -1,7 +1,5 @@
-#!/usr/bin/env python
-
import os
-import unittest
+import pytest
import astropy.units as u
import numpy as np
@@ -11,12 +9,12 @@
from pinttestdata import datadir
-class TestTroposphereDelay(unittest.TestCase):
+class TestTroposphereDelay:
MIN_ALT = 5 # the minimum altitude in degrees for testing the delay model
FLOAT_THRESHOLD = 1e-12 #
- def setUp(self):
+ def setup_method(self):
# parfile = os.path.join(datadir, "J1744-1134.basic.par")
# ngc = os.path.join(datadir, "NGC6440E")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d0b0176b0..18988b4d0 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -53,6 +53,9 @@
find_prefix_bytime,
merge_dmx,
convert_dispersion_measure,
+ print_color_examples,
+ parse_time,
+ info_string,
)
@@ -843,3 +846,28 @@ def test_convert_dm():
dm_codata = convert_dispersion_measure(dm)
assert np.isfinite(dm_codata)
+
+
+def test_print_color_examples():
+ print_color_examples()
+
+
+@pytest.mark.parametrize(
+ "t",
+ [
+ Time(55555, format="pulsar_mjd", scale="tdb", precision=9),
+ 55555 * u.d,
+ 55555.0,
+ 55555,
+ "55555",
+ ],
+)
+def test_parse_time(t):
+ assert parse_time(t, scale="tdb") == Time(
+ 55555, format="pulsar_mjd", scale="tdb", precision=9
+ )
+
+
+def test_info_str():
+ info = info_string()
+ dinfo = info_string(detailed=True)
diff --git a/tests/test_widebandTOA_fitting.py b/tests/test_widebandTOA_fitting.py
index 99c740b5a..dc8e69787 100644
--- a/tests/test_widebandTOA_fitting.py
+++ b/tests/test_widebandTOA_fitting.py
@@ -1,5 +1,4 @@
-""" Various of tests for the general data fitter using wideband TOAs.
-"""
+""" Various of tests for the general data fitter using wideband TOAs."""
import os
@@ -14,7 +13,7 @@
class TestWidebandTOAFitter:
- def setup(self):
+ def setup_method(self):
self.model = get_model("J1614-2230_NANOGrav_12yv3.wb.gls.par")
self.toas = get_TOAs(
"J1614-2230_NANOGrav_12yv3.wb.tim", ephem="DE436", bipm_version="BIPM2015"
@@ -56,7 +55,7 @@ def test_fitting_no_full_cov(self):
prefit_pint = fitter.resids_init.toa.time_resids
prefit_tempo = self.tempo_res[:, 0] * u.us
diff_prefit = (prefit_pint - prefit_tempo).to(u.ns)
- # 50 ns is the difference of PINT tempo precession and nautation model.
+ # 50 ns is the difference of PINT tempo precession and nutation model.
assert np.abs(diff_prefit - diff_prefit.mean()).max() < 50 * u.ns
postfit_pint = fitter.resids.toa.time_resids
postfit_tempo = self.tempo_res[:, 1] * u.us
diff --git a/tests/test_wideband_dm_data.py b/tests/test_wideband_dm_data.py
index c48b13729..c62d5d9f3 100644
--- a/tests/test_wideband_dm_data.py
+++ b/tests/test_wideband_dm_data.py
@@ -1,5 +1,4 @@
-""" Various of tests on the wideband DM data
-"""
+""" Various of tests on the wideband DM data"""
import io
import os
@@ -88,7 +87,7 @@ def wb_toas_all(wb_model):
class TestDMData:
- def setup(self):
+ def setup_method(self):
self.model = get_model("J1614-2230_NANOGrav_12yv3.wb.gls.par")
self.toas = get_TOAs("J1614-2230_NANOGrav_12yv3.wb.tim")
toa_backends, valid_flags = self.toas.get_flag_value("fe")
diff --git a/tests/test_wls_fitter.py b/tests/test_wls_fitter.py
index 0a2f0b64d..1c9532e7a 100644
--- a/tests/test_wls_fitter.py
+++ b/tests/test_wls_fitter.py
@@ -1,6 +1,5 @@
-#! /usr/bin/env python
import os
-import unittest
+import pytest
from pint.models.model_builder import get_model
from pint import toa
@@ -8,11 +7,11 @@
from pinttestdata import datadir
-class Testwls(unittest.TestCase):
+class Testwls:
"""Compare delays from the dd model with tempo and PINT"""
@classmethod
- def setUpClass(cls):
+ def setup_class(cls):
os.chdir(datadir)
cls.par = "B1855+09_NANOGrav_dfg+12_TAI_FB90.par"
cls.tim = "B1855+09_NANOGrav_dfg+12.tim"
diff --git a/tests/test_zima.py b/tests/test_zima.py
index 0d6260423..d05664669 100644
--- a/tests/test_zima.py
+++ b/tests/test_zima.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
import os
import sys
from io import StringIO
@@ -7,6 +6,7 @@
import pytest
from pinttestdata import datadir
+import matplotlib
import pint.scripts.zima as zima
from pint.models import get_model_and_toas
from pint.residuals import Residuals
@@ -67,8 +67,6 @@ def test_wb_result_with_noise(tmp_path):
def test_zima_plot(tmp_path):
- import matplotlib
-
matplotlib.use("Agg")
parfile = os.path.join(datadir, "NGC6440E.par")
@@ -83,8 +81,6 @@ def test_zima_plot(tmp_path):
def test_zima_fuzzdays(tmp_path):
- import matplotlib
-
matplotlib.use("Agg")
parfile = os.path.join(datadir, "NGC6440E.par")
diff --git a/tests/utils.py b/tests/utils.py
index 9b028b2b3..005c197bb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -45,12 +45,11 @@ def verify_stand_alone_binary_parameter_updates(m):
if pint_par_name is None:
continue
pint_par = getattr(m, pint_par_name)
- if pint_par.value is not None:
- if not isinstance(pint_par, funcParameter):
- if hasattr(standalone_par, "value"):
- # Test for astropy quantity
- assert pint_par.value == standalone_par.value
- else:
- # Test for non-astropy quantity parameters.
- assert pint_par.value == standalone_par
+ if pint_par.value is not None and not isinstance(pint_par, funcParameter):
+ if hasattr(standalone_par, "value"):
+ # Test for astropy quantity
+ assert pint_par.value == standalone_par.value
+ else:
+ # Test for non-astropy quantity parameters.
+ assert pint_par.value == standalone_par
return
diff --git a/tox.ini b/tox.ini
index 93cfe9d3a..c996f91f6 100644
--- a/tox.ini
+++ b/tox.ini
@@ -67,7 +67,7 @@ deps =
scipy==1.4.1
pytest
coverage
- hypothesis
+ hypothesis<=6.72.0
commands = {posargs:pytest}
[testenv:report]
diff --git a/validation/wls_vs_gls.py b/validation/wls_vs_gls.py
index d926ab7e4..7c4ca3317 100644
--- a/validation/wls_vs_gls.py
+++ b/validation/wls_vs_gls.py
@@ -26,7 +26,7 @@ def change_parfile(filename, param, value):
with open(abs_path, "w") as new_file:
with open(filename) as old_file:
for line in old_file:
- if line.startswith(param + " "):
+ if line.startswith(f"{param} "):
l = line.split()
line = line.replace(l[1], strv)
line = line.replace(" 0 ", " 1 ")
@@ -48,7 +48,7 @@ def perturb_param(f, param, h, source, target):
"""Perturbate parameter value and change the corresponding par file"""
reset_model(source, target, f)
pn = f.model.match_param_aliases(param)
- if not pn == "":
+ if pn != "":
par = getattr(f.model, pn)
orv = par.value
par.value = (1 + h) * orv
@@ -61,7 +61,7 @@ def perturb_param(f, param, h, source, target):
def check_tempo_output(parf, timf, result_par):
"""Check out tempo output"""
- a = subprocess.check_output("tempo -f " + parf + " " + timf, shell=True)
+ a = subprocess.check_output(f"tempo -f {parf} {timf}", shell=True)
tempo_m = mb.get_model(result_par)
info_idx = a.index("Weighted RMS residual: pre-fit")
res = a[info_idx:-1]
@@ -69,13 +69,13 @@ def check_tempo_output(parf, timf, result_par):
mpost = re.search("post-fit(.+?)us", res)
mchi = re.search("=(.+?)pre/post", res)
try:
- pre_rms = float(mpre.group(1))
- post_rms = float(mpost.group(1))
- chi = float(mchi.group(1))
+ pre_rms = float(mpre[1])
+ post_rms = float(mpost[1])
+ chi = float(mchi[1])
except ValueError:
- pre_rms = mpre.group(1)
- post_rms = mpost.group(1)
- chi = mchi.group(1)
+ pre_rms = mpre[1]
+ post_rms = mpost[1]
+ chi = mchi[1]
if chi.startswith("**"):
chi = 0.0
return tempo_m, pre_rms, post_rms, chi
@@ -84,25 +84,25 @@ def check_tempo_output(parf, timf, result_par):
def check_tempo2_output(parf, timf, p, result_par):
"""Check out tempo2 output"""
res = subprocess.check_output(
- "tempo2 -f " + parf + " " + timf + " -norescale -newpar", shell=True
+ f"tempo2 -f {parf} {timf} -norescale -newpar", shell=True
)
mpre = re.search("RMS pre-fit residual =(.+?)(us)", res)
mpost = re.search("RMS post-fit residual =(.+?)(us)", res)
mchi = re.search("Chisqr/nfree =(.+?)/", res)
m = mb.get_model(result_par)
- mpostn = re.findall(r"\d+\.\d+", mpost.group(1))
+ mpostn = re.findall(r"\d+\.\d+", mpost[1])
try:
- pre_rms = float(mpre.group(1))
+ pre_rms = float(mpre[1])
except ValueError:
- pre_rms = mpre.group(1)
+ pre_rms = mpre[1]
try:
post_rms = float(mpostn[0])
except ValueError:
- post_rms = mpost.group(1)
+ post_rms = mpost[1]
try:
- chi = float(mchi.group(1)) / len(t.table)
- except:
- chi = mchi.group(1)
+ chi = float(mchi[1]) / len(t.table)
+ except Exception:
+ chi = mchi[1]
pv = getattr(m, p).value
pu = getattr(m, p).uncertainty_value
return pv, pu, post_rms, chi
@@ -116,9 +116,8 @@ def check_tempo2_output(parf, timf, p, result_par):
base = []
for fn in parfiles:
b = fn.replace("_ori.par", "")
- if not b.endswith(".par"):
- if b not in base:
- base.append(b.replace(".gls", ""))
+ if not b.endswith(".par") and b not in base:
+ base.append(b.replace(".gls", ""))
per_step = {
"A1": 1e-05,
@@ -148,14 +147,14 @@ def check_tempo2_output(parf, timf, p, result_par):
# if not b_name.startswith('B1855'):
# continue
if b_name.endswith("+12"):
- par = b_name + "_ori.par"
- tempo_par = b_name + "_ptb.par"
+ par = f"{b_name}_ori.par"
+ tempo_par = f"{b_name}_ptb.par"
else:
- par = b_name + ".gls_ori.par"
- tempo_par = b_name + ".gls_ptb.par"
- tim = b_name + ".tim"
+ par = f"{b_name}.gls_ori.par"
+ tempo_par = f"{b_name}.gls_ptb.par"
+ tim = f"{b_name}.tim"
m = mb.get_model(par)
- result_par = m.PSR.value + ".par"
+ result_par = f"{m.PSR.value}.par"
res[b_name] = {}
cmd = subprocess.list2cmdline(["grep", "'EPHEM'", par])
out = subprocess.check_output(cmd, shell=True)
@@ -323,11 +322,11 @@ def check_tempo2_output(parf, timf, p, result_par):
)
# res[b_name] = (pt_table, pt2_table)
- out_name_pt = b_name + "_PT.html"
- print("Write " + out_name_pt)
+ out_name_pt = f"{b_name}_PT.html"
+ print(f"Write {out_name_pt}")
pt_table.write(out_name_pt, format="ascii.html", overwrite=True)
- out_name_pt2 = b_name + "_PT2.html"
- print("Write " + out_name_pt2)
+ out_name_pt2 = f"{b_name}_PT2.html"
+ print(f"Write {out_name_pt2}")
pt2_table.write(out_name_pt2, format="ascii.html", overwrite=True)
reset_model(par, tempo_par, f)
# subprocess.call("cat " + out_name_pt + ">>" + out_name_pt, shell=True)