Skip to content

Commit

Permalink
Add JAX Advanced Tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Sep 10, 2024
1 parent 2d74c6a commit 0fabd73
Show file tree
Hide file tree
Showing 20 changed files with 22 additions and 4,427 deletions.
5 changes: 0 additions & 5 deletions docs/_tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ JAX 201
:maxdepth: 1

parallelism
advanced-autodiff
gradient-checkpointing
advanced-debugging
external-callbacks
profiling-and-performance

JAX 301
Expand All @@ -50,6 +47,4 @@ JAX 301
.. toctree::
:maxdepth: 1

jax-primitives
jaxpr
advanced-compilation
File renamed without changes.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,4 +357,6 @@ def linkcode_resolve(domain, info):
'jax-101/07-state.md': 'stateful-computations.md',
'jax-101/08-pjit.rst': 'sharded-computation.md',
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
}
2 changes: 0 additions & 2 deletions docs/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ that use or interface with JAX.
:caption: Extensible JAX internals
:maxdepth: 1

notebooks/How_JAX_primitives_work
jaxpr
notebooks/Writing_custom_interpreters_in_Jax
Custom_Operation_for_GPUs
jax.extend
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion docs/ffi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
"We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/ffi.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
```

We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:

```{code-cell} ipython3
jax.make_jaxpr(jax.vmap(rms_norm))(x)
Expand Down
2 changes: 1 addition & 1 deletion docs/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Glossary of terms
jaxpr
Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
See :ref:`understanding-jaxprs` for more discussion and examples.
See :ref:`jax-internals-jaxpr` for more discussion and examples.

JIT
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
Expand Down
File renamed without changes.
7 changes: 4 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,21 @@ maintains an up-to-date list.

installation
quickstart
notebooks/Common_Gotchas_in_JAX
faq

.. toctree::
:hidden:
:maxdepth: 1

tutorials

notebooks/Common_Gotchas_in_JAX

faq

.. toctree::
:hidden:
:maxdepth: 2
:caption: Resources
:caption: More guides/resources

user_guides
advanced_guide
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```

You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`.
You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`.

```{code-cell}
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
Expand Down
File renamed without changes.
Loading

0 comments on commit 0fabd73

Please sign in to comment.