Skip to content

Commit

Permalink
Add/update JAX Advanced Tutorials docs, ToC structure
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Sep 20, 2024
1 parent 6b93b35 commit 0cf040c
Show file tree
Hide file tree
Showing 20 changed files with 20 additions and 4,426 deletions.
5 changes: 0 additions & 5 deletions docs/_tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ JAX 201
:maxdepth: 1

parallelism
advanced-autodiff
gradient-checkpointing
advanced-debugging
external-callbacks
profiling-and-performance

JAX 301
Expand All @@ -50,6 +47,4 @@ JAX 301
.. toctree::
:maxdepth: 1

jax-primitives
jaxpr
advanced-compilation
File renamed without changes.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,4 +360,6 @@ def linkcode_resolve(domain, info):
'jax-101/07-state.md': 'stateful-computations.md',
'jax-101/08-pjit.rst': 'sharded-computation.md',
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
}
2 changes: 0 additions & 2 deletions docs/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ that use or interface with JAX.
:caption: Extensible JAX internals
:maxdepth: 1

notebooks/How_JAX_primitives_work
jaxpr
notebooks/Writing_custom_interpreters_in_Jax
Custom_Operation_for_GPUs
jax.extend
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion docs/ffi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
"We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/ffi.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
```

We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:

```{code-cell} ipython3
jax.make_jaxpr(jax.vmap(rms_norm))(x)
Expand Down
2 changes: 1 addition & 1 deletion docs/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Glossary of terms
jaxpr
Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
See :ref:`understanding-jaxprs` for more discussion and examples.
See :ref:`jax-internals-jaxpr` for more discussion and examples.

JIT
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
Expand Down
File renamed without changes.
7 changes: 4 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,21 @@ maintains an up-to-date list.

installation
quickstart
notebooks/Common_Gotchas_in_JAX
faq

.. toctree::
:hidden:
:maxdepth: 1

tutorials

notebooks/Common_Gotchas_in_JAX

faq

.. toctree::
:hidden:
:maxdepth: 2
:caption: Resources
:caption: More guides/resources

user_guides
advanced_guide
Expand Down
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 0cf040c

Please sign in to comment.