Skip to content

Commit cfe934c

Browse files
committed
Fix some doc build warnings
1 parent f4b5ff9 commit cfe934c

File tree

6 files changed

+18
-9
lines changed

6 files changed

+18
-9
lines changed

docs/CHANGELOG.rst

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ jax 0.2.9 (Unreleased)
1717

1818
* Breaking changes:
1919

20-
* ``jax.ops.segment_sum` now drops segment IDs that are out of range rather
20+
* :func:`jax.ops.segment_sum` now drops segment IDs that are out of range rather
2121
than wrapping them into the segment ID space. This was done for performance
2222
reasons.
2323

2424
* `GitHub commits <https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9>`__.
2525

2626
* New features:
2727

28-
* Extend the `jax.experimental.loops` module with support for pytrees. Improved
28+
* Extend the :mod:`jax.experimental.loops` module with support for pytrees. Improved
2929
error checking and error messages.
3030

3131
* Add :func:`jax.experimental.enable_x64` and :func:`jax.experimental.disable_x64`.
@@ -78,7 +78,7 @@ jax 0.2.8 (January 12 2021)
7878

7979
* New features:
8080

81-
* New flag for debugging ``inf``s, analagous to that for ``NaN``s (`#5224 <https://github.com/google/jax/pull/5224>`_).
81+
* New flag for debugging ``inf``, analagous to that for ``NaN`` (`#5224 <https://github.com/google/jax/pull/5224>`_).
8282

8383
jax 0.2.7 (Dec 4 2020)
8484
----------------------
@@ -437,7 +437,7 @@ jax 0.1.67 (May 12, 2020)
437437

438438
* Notable changes:
439439

440-
* The visibility of names exported from :py:module:`jax.numpy` has been
440+
* The visibility of names exported from :mod:`jax.numpy` has been
441441
tightened. This may break code that was making use of names that were
442442
previously exported accidentally.
443443

@@ -585,7 +585,7 @@ jax 0.1.60 (March 17, 2020)
585585
``static_argnums`` in :py:func:`jax.jit`.
586586
* Improved error messages for when tracers are mistakenly saved in global state.
587587
* Added :py:func:`jax.nn.one_hot` utility function.
588-
* Added :py:module:`jax.experimental.jet` for exponentially faster
588+
* Added :mod:`jax.experimental.jet` for exponentially faster
589589
higher-order automatic differentiation.
590590
* Added more correctness checking to arguments of :py:func:`jax.lax.broadcast_in_dim`.
591591

docs/conf.py

+2
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@
109109
'notebooks/XLA_in_Python.ipynb',
110110
# Sometimes sphinx reads its own outputs as inputs!
111111
'build/html',
112+
'notebooks/README.md',
113+
'README.md',
112114
]
113115

114116
# The name of the Pygments (syntax highlighting) style to use.

docs/developer.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ to set up a CUDA environment.
8282
You can either install Python using its
8383
`Windows installer <https://www.python.org/downloads/>`_, or if you prefer, you
8484
can use `Anaconda <https://docs.anaconda.com/anaconda/install/windows/>`_
85-
or `Miniconda <https://docs.conda.io/en/latest/miniconda.html#windows-installers>`_
85+
or `Miniconda <https://docs.conda.io/en/latest/miniconda.html#windows-installers>`__
8686
to setup a Python environment.
8787

8888
Some targets of Bazel use bash utilities to do scripting, so `MSYS2 <https://www.msys2.org>`_
@@ -186,7 +186,7 @@ To rebuild the documentation, install several packages::
186186

187187
You must also install ``pandoc`` in order to regenerate the notebooks.
188188
See `Install Pandoc <https://pandoc.org/installing.html>`_,
189-
or using `Miniconda <https://docs.conda.io/en/latest/miniconda.html>`_ which
189+
or using `Miniconda <https://docs.conda.io/en/latest/miniconda.html>`__ which
190190
I have used successfully on the Mac: ``conda install -c conda-forge pandoc``.
191191
If you do not want to install ``pandoc`` then you should regenerate the documentation
192192
without the notebooks.

docs/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ For an introduction to JAX, start at the
4242
pytrees
4343
rank_promotion_warning
4444
type_promotion
45+
custom_vjp_update
4546

4647
.. toctree::
4748
:maxdepth: 2

jax/_src/ops/scatter.py

+5
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def index_add(x, idx, y, indices_are_sorted=False, unique_indices=False):
107107
108108
Returns the value of `x` that would result from the
109109
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
110+
110111
x[idx] += y
111112
112113
Note the `index_add` operator is pure; `x` itself is
@@ -149,6 +150,7 @@ def index_mul(x, idx, y, indices_are_sorted=False, unique_indices=False):
149150
150151
Returns the value of `x` that would result from the
151152
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
153+
152154
x[idx] *= y
153155
154156
Note the `index_mul` operator is pure; `x` itself is
@@ -191,6 +193,7 @@ def index_min(x, idx, y, indices_are_sorted=False, unique_indices=False):
191193
192194
Returns the value of `x` that would result from the
193195
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
196+
194197
x[idx] = minimum(x[idx], y)
195198
196199
Note the `index_min` operator is pure; `x` itself is
@@ -230,6 +233,7 @@ def index_max(x, idx, y, indices_are_sorted=False, unique_indices=False):
230233
231234
Returns the value of `x` that would result from the
232235
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
236+
233237
x[idx] = maximum(x[idx], y)
234238
235239
Note the `index_max` operator is pure; `x` itself is
@@ -269,6 +273,7 @@ def index_update(x, idx, y, indices_are_sorted=False, unique_indices=False):
269273
270274
Returns the value of `x` that would result from the
271275
NumPy-style :mod:`indexed assignment <numpy.doc.indexing>`::
276+
272277
x[idx] = y
273278
274279
Note the `index_update` operator is pure; `x` itself is

jax/_src/scipy/linalg.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,9 @@ def triu(m, k=0):
245245
246246
The number of required squarings = max(0, ceil(log2(norm(A)) - c)
247247
where norm() denotes the L1 norm, and
248-
c=2.42 for float64 or complex128,
249-
c=1.97 for float32 or complex64
248+
249+
- c=2.42 for float64 or complex128,
250+
- c=1.97 for float32 or complex64
250251
""")
251252

252253
@_wraps(scipy.linalg.expm, lax_description=_expm_description)

0 commit comments

Comments
 (0)