diff --git a/Makefile b/Makefile index 9cb0a920b4..ee0b6945f8 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,9 @@ scrub: FORCE find tutorial -name "*.ipynb" | xargs python tutorial/source/cleannb.py doctest: FORCE - python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro + # We skip testing pyro.distributions.torch wrapper classes because + # they include torch docstrings which are tested upstream. + python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py perf-test: FORCE bash scripts/perf_test.sh ${ref} diff --git a/docs/source/conf.py b/docs/source/conf.py index 806852abb8..e9c4bf84b8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,6 +47,7 @@ "sphinx.ext.graphviz", # "sphinx.ext.autodoc", "sphinx.ext.doctest", + 'sphinx.ext.napoleon', ] # Disable documentation inheritance so as to avoid inheriting diff --git a/pyro/contrib/forecast/forecaster.py b/pyro/contrib/forecast/forecaster.py index 505660bb08..14e9226d89 100644 --- a/pyro/contrib/forecast/forecaster.py +++ b/pyro/contrib/forecast/forecaster.py @@ -70,6 +70,8 @@ def model(self, zero_data, covariates): @property def time_plate(self): """ + Helper to create a ``pyro.plate`` over time. + :returns: A plate named "time" with size ``covariates.size(-2)`` and ``dim=-1``. This is available only during model execution. :rtype: :class:`~pyro.plate` diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index c6e8c7e1d2..805b1b83b6 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import math +import re +import textwrap import torch @@ -334,6 +336,16 @@ def support(self): return constraints.interval(self._unbroadcasted_low, self._unbroadcasted_high) +def _cat_docstrings(*docstrings): + result = "\n".join(textwrap.dedent(s.lstrip("\n")) for s in docstrings) + result = re.sub("\n\n+", "\n\n", result) + # Drop torch-specific lines. + result = "".join( + line for line in result.splitlines(keepends=True) if "xdoctest" not in line + ) + return result + + # Programmatically load all distributions from PyTorch. __all__ = [] for _name, _Dist in torch.distributions.__dict__.items(): @@ -354,10 +366,11 @@ def support(self): _PyroDist.__doc__ = """ Wraps :class:`{}.{}` with :class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`. + """.format( _Dist.__module__, _Dist.__name__ ) - + _PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__) __all__.append(_name)