From 13f554eb280da859469ed8c05d2b9197b370f639 Mon Sep 17 00:00:00 2001 From: rtviii Date: Thu, 20 Jul 2023 15:26:53 +0200 Subject: [PATCH 1/3] grabbing docstring from torch modules directly, build still failes with indent errors --- docs/source/conf.py | 1 + pyro/distributions/torch.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 806852abb8..2d530bbfc9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -38,6 +38,7 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + 'sphinxcontrib.napoleon', "sphinx.ext.intersphinx", # "sphinx.ext.todo", # "sphinx.ext.mathjax", # diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index c6e8c7e1d2..6c0e678e93 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -334,6 +334,13 @@ def support(self): return constraints.interval(self._unbroadcasted_low, self._unbroadcasted_high) +def doctest_disable(docstring): + _ = "" + for line in docstring.splitlines(): + _ += line + "#doctest: +DISABLE" + return _ + + # Programmatically load all distributions from PyTorch. __all__ = [] for _name, _Dist in torch.distributions.__dict__.items(): @@ -346,18 +353,25 @@ def support(self): try: _PyroDist = locals()[_name] + torchDistDocstring = _Dist.__doc__ + except KeyError: _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {}) _PyroDist.__module__ = __name__ locals()[_name] = _PyroDist + torchDistDocstring = None _PyroDist.__doc__ = """ Wraps :class:`{}.{}` with :class:`~pyro.distributions.torch_distribution.TorchDistributionMixin`. + """.format( _Dist.__module__, _Dist.__name__ + ) + ( + "\n\n" + doctest_disable(torchDistDocstring) + if torchDistDocstring is not None + else "" ) - __all__.append(_name) From 0905aedf4b9f7a44ddfc38627103ecf7f628a966 Mon Sep 17 00:00:00 2001 From: rtviii Date: Thu, 20 Jul 2023 15:30:52 +0200 Subject: [PATCH 2/3] replaced outdated sphinxcontrib.napoleon with extension sphinx.ext.napoleon --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 2d530bbfc9..e9c4bf84b8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -38,7 +38,6 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinxcontrib.napoleon', "sphinx.ext.intersphinx", # "sphinx.ext.todo", # "sphinx.ext.mathjax", # @@ -48,6 +47,7 @@ "sphinx.ext.graphviz", # "sphinx.ext.autodoc", "sphinx.ext.doctest", + 'sphinx.ext.napoleon', ] # Disable documentation inheritance so as to avoid inheriting From 534195b28c113ae2af6795dfeca9610983650a2e Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Mon, 24 Jul 2023 15:06:19 -0700 Subject: [PATCH 3/3] Fix docs --- Makefile | 4 +++- pyro/contrib/forecast/forecaster.py | 2 ++ pyro/distributions/torch.py | 23 +++++++++++------------ 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index 9cb0a920b4..ee0b6945f8 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,9 @@ scrub: FORCE find tutorial -name "*.ipynb" | xargs python tutorial/source/cleannb.py doctest: FORCE - python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro + # We skip testing pyro.distributions.torch wrapper classes because + # they include torch docstrings which are tested upstream. + python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py perf-test: FORCE bash scripts/perf_test.sh ${ref} diff --git a/pyro/contrib/forecast/forecaster.py b/pyro/contrib/forecast/forecaster.py index 505660bb08..14e9226d89 100644 --- a/pyro/contrib/forecast/forecaster.py +++ b/pyro/contrib/forecast/forecaster.py @@ -70,6 +70,8 @@ def model(self, zero_data, covariates): @property def time_plate(self): """ + Helper to create a ``pyro.plate`` over time. + :returns: A plate named "time" with size ``covariates.size(-2)`` and ``dim=-1``. This is available only during model execution. :rtype: :class:`~pyro.plate` diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 6c0e678e93..805b1b83b6 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import math +import re +import textwrap import torch @@ -334,11 +336,14 @@ def support(self): return constraints.interval(self._unbroadcasted_low, self._unbroadcasted_high) -def doctest_disable(docstring): - _ = "" - for line in docstring.splitlines(): - _ += line + "#doctest: +DISABLE" - return _ +def _cat_docstrings(*docstrings): + result = "\n".join(textwrap.dedent(s.lstrip("\n")) for s in docstrings) + result = re.sub("\n\n+", "\n\n", result) + # Drop torch-specific lines. + result = "".join( + line for line in result.splitlines(keepends=True) if "xdoctest" not in line + ) + return result # Programmatically load all distributions from PyTorch. @@ -353,13 +358,10 @@ def doctest_disable(docstring): try: _PyroDist = locals()[_name] - torchDistDocstring = _Dist.__doc__ - except KeyError: _PyroDist = type(_name, (_Dist, TorchDistributionMixin), {}) _PyroDist.__module__ = __name__ locals()[_name] = _PyroDist - torchDistDocstring = None _PyroDist.__doc__ = """ Wraps :class:`{}.{}` with @@ -367,11 +369,8 @@ def doctest_disable(docstring): """.format( _Dist.__module__, _Dist.__name__ - ) + ( - "\n\n" + doctest_disable(torchDistDocstring) - if torchDistDocstring is not None - else "" ) + _PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__) __all__.append(_name)