Skip to content

Commit

Permalink
Add/update JAX Advanced Tutorials docs, ToC structure
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Sep 20, 2024
1 parent a533635 commit 2457601
Show file tree
Hide file tree
Showing 20 changed files with 85 additions and 4,437 deletions.
5 changes: 0 additions & 5 deletions docs/_tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ JAX 201
:maxdepth: 1

parallelism
advanced-autodiff
gradient-checkpointing
advanced-debugging
external-callbacks
profiling-and-performance

JAX 301
Expand All @@ -50,6 +47,4 @@ JAX 301
.. toctree::
:maxdepth: 1

jax-primitives
jaxpr
advanced-compilation
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def meta_loss_fn(params, data):
meta_grads = jax.grad(meta_loss_fn)(params, data)
```

(stopping-gradients)=

### Stopping gradients

Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph.
Expand Down Expand Up @@ -571,7 +571,7 @@ print("Naive full Hessian materialization")

### Jacobian-Matrix and Matrix-Jacobian products

Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:
Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products:

```{code-cell}
# Isolate the function from the weight matrix to the predictions
Expand Down
8 changes: 4 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def _do_not_evaluate_in_jax(
'pallas/quickstart.md',
'pallas/tpu/pipelining.md',
'pallas/tpu/distributed.md',
'pallas/tpu/sparse.md',
'pallas/tpu/matmul.md',
'jep/9407-type-promotion.md',
'autodidax.md',
Expand Down Expand Up @@ -168,7 +167,7 @@ def _do_not_evaluate_in_jax(
# documentation.
html_theme_options = {
'show_toc_level': 2,
'repository_url': 'https://github.com/jax-ml/jax',
'repository_url': 'https://github.com/google/jax',
'use_repository_button': True, # add a "link to repository" button
'navigation_with_keys': False,
}
Expand Down Expand Up @@ -225,7 +224,6 @@ def _do_not_evaluate_in_jax(
'pallas/quickstart.*',
'pallas/tpu/pipelining.*',
'pallas/tpu/distributed.*',
'pallas/tpu/sparse.*',
'pallas/tpu/matmul.*',
'sharded-computation.*',
'distributed_data_loading.*'
Expand Down Expand Up @@ -345,7 +343,7 @@ def linkcode_resolve(domain, info):
return None
filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__))
lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else ""
return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}"
return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}"

# Generate redirects from deleted files to new sources
rediraffe_redirects = {
Expand All @@ -360,4 +358,6 @@ def linkcode_resolve(domain, info):
'jax-101/07-state.md': 'stateful-computations.md',
'jax-101/08-pjit.rst': 'sharded-computation.md',
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
}
2 changes: 0 additions & 2 deletions docs/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ that use or interface with JAX.
:caption: Extensible JAX internals
:maxdepth: 1

notebooks/How_JAX_primitives_work
jaxpr
notebooks/Writing_custom_interpreters_in_Jax
Custom_Operation_for_GPUs
jax.extend
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions docs/ffi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
"We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
]
},
{
Expand Down Expand Up @@ -406,7 +406,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)."
]
},
{
Expand Down Expand Up @@ -492,7 +492,7 @@
"source": [
"At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n",
"One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n",
"JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n",
"JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n",
"\n",
"One other JAX feature that this example doesn't support is higher-order AD.\n",
"It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/ffi.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
```

We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:

```{code-cell} ipython3
jax.make_jaxpr(jax.vmap(rms_norm))(x)
Expand All @@ -333,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5):
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
```

If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues).

+++

Expand Down Expand Up @@ -406,7 +406,7 @@ np.testing.assert_allclose(

At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.
One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.
JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.
JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.

One other JAX feature that this example doesn't support is higher-order AD.
It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.
Expand Down
2 changes: 1 addition & 1 deletion docs/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Glossary of terms
jaxpr
Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
See :ref:`understanding-jaxprs` for more discussion and examples.
See :ref:`jax-internals-jaxpr` for more discussion and examples.

JIT
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
Expand Down
File renamed without changes.
7 changes: 4 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,21 @@ maintains an up-to-date list.

installation
quickstart
notebooks/Common_Gotchas_in_JAX
faq

.. toctree::
:hidden:
:maxdepth: 1

tutorials

notebooks/Common_Gotchas_in_JAX

faq

.. toctree::
:hidden:
:maxdepth: 2
:caption: Resources
:caption: More guides/resources

user_guides
advanced_guide
Expand Down
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 2457601

Please sign in to comment.