
DOC: replace old tutorials with new content #20819

Status: Closed · 1 commit
README.md (1 addition, 1 deletion)
@@ -83,7 +83,7 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
-- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
+- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
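
As a quick taste of what those starter notebooks cover, here is a minimal, illustrative sketch of the `grad`/`jit`/`vmap` pattern from the README snippet above (the loss function and shapes are assumptions for the example):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model; illustrative only.
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_loss = jax.grad(loss)          # gradient w.r.t. the first argument, w
fast_grad = jax.jit(grad_loss)      # XLA-compile the gradient function
# Per-example gradients: map over the leading axis of x and y, but not w.
perex_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w, x, y = jnp.ones(3), jnp.ones((8, 3)), jnp.zeros(8)
print(perex_grads(w, x, y).shape)   # (8, 3): one gradient per example
```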

**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
@@ -1,7 +1,7 @@
# Advanced compilation

```{note}
-This is a placeholder for a section in the new {ref}`jax-tutorials`.
+This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

For the time being, you may find some related content in the old documentation:
- {doc}`../aot`
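
For context on what that section will cover, the ahead-of-time workflow described in the existing AOT docs looks roughly like this sketch (using the public `jax.jit(...).lower(...).compile()` API; the function and input are illustrative):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) + x

x = jnp.arange(4.0)
lowered = jax.jit(f).lower(x)    # trace and lower for this shape/dtype
print(lowered.as_text()[:200])   # inspect the lowered IR before compiling
compiled = lowered.compile()     # finish compilation ahead of time
print(compiled(x))               # run the precompiled executable
```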
@@ -15,7 +15,7 @@ kernelspec:
(advanced-debugging)=
# Advanced debugging
```{note}
-This is a placeholder for a section in the new {ref}`jax-tutorials`.
+This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

For the time being, you may find some related content in the old documentation:
- {doc}`../debugging/index`
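
For context, a minimal sketch of the runtime debugging utilities the existing pages describe (`jax.debug.print` is the public API; the function is illustrative):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Unlike Python's print, jax.debug.print works under jit tracing.
    jax.debug.print("y = {}", y)
    return y + 1

f(jnp.arange(3.0))
```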
docs/_tutorials/index.rst (57 additions, 0 deletions)
@@ -0,0 +1,57 @@
:orphan:

.. _jax-tutorials-draft:

JAX tutorials draft
===================

.. note::

   The tutorials below are a work in progress; for the time being, please refer
   to the older tutorial content, including :ref:`beginner-guide`,
   :ref:`user-guides`, and the now-deleted *JAX 101* tutorials.

JAX 101
-------
Mostly finalized at :ref:`jax-tutorials`!

.. toctree::
   :maxdepth: 1

   ../quickstart
   ../key-concepts
   ../jit-compilation
   ../automatic-vectorization
   ../automatic-differentiation
   ../debugging
   ../random-numbers
   ../working-with-pytrees
   ../sharded-computation
   ../stateful-computations
   simple-neural-network


JAX 201
-------

.. toctree::
   :maxdepth: 1

   parallelism
   advanced-autodiff
   gradient-checkpointing
   advanced-debugging
   external-callbacks
   profiling-and-performance


JAX 301
-------

.. toctree::
   :maxdepth: 1

   jax-primitives
   jaxpr
   advanced-compilation
File renamed without changes.
File renamed without changes.
@@ -1,7 +1,7 @@
# Parallel computation

```{note}
-This is a placeholder for a section in the new {ref}`jax-tutorials`.
+This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

For the time being, you may find some related content in the old documentation:
- {doc}`../multi_process`
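
For context, a minimal single-process sharding sketch in the spirit of the sharded-computation tutorial (illustrative; assumes the local device count divides the array length):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Lay all local devices out along one mesh axis. On a single-device
# machine this still runs; it just doesn't distribute anything.
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("i")))

# jit-compiled computations propagate the sharding of their inputs.
print(jax.jit(lambda a: a * 2)(x).sharding)
```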
@@ -1,7 +1,7 @@
# Profiling and performance

```{note}
-This is a placeholder for a section in the new {ref}`jax-tutorials`.
+This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.

For the time being, you may find some related content in the old documentation:
- {doc}`../profiling`
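
For context, a minimal profiling sketch using the public `jax.profiler.trace` context manager (the trace directory and workload are illustrative):

```python
import jax
import jax.numpy as jnp

# Write a trace that can be opened in TensorBoard's profiler or Perfetto.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((1000, 1000))
    y = jax.jit(lambda a: a @ a)(x)
    y.block_until_ready()  # make sure the async work lands inside the trace
```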
@@ -1,5 +1,5 @@
# Example: Writing a simple neural network

```{note}
-This is a placeholder for a section in the new {ref}`jax-tutorials`.
+This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
```
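
Until that section lands, here is a sketch of the kind of model it will cover: a small multilayer perceptron in plain JAX, with no neural-network library (all names and sizes are illustrative):

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 32, 1)):
    # One (W, b) pair per layer, with small random weights.
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) * 0.1, jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def forward(params, x):
    for W, b in params[:-1]:
        x = jax.nn.relu(x @ W + b)  # hidden layers
    W, b = params[-1]
    return x @ W + b                # linear output layer

params = init_params(jax.random.PRNGKey(0))
print(forward(params, jnp.ones((4, 2))).shape)  # (4, 1)
```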
docs/beginner_guide.rst (3 additions, 3 deletions)
@@ -5,7 +5,7 @@
Getting Started with JAX
========================
Welcome to JAX! The JAX documentation contains a number of useful resources for getting started.
-:doc:`notebooks/quickstart` is the easiest place to jump in and get an overview of the JAX project.
+:doc:`quickstart` is the easiest place to jump in and get an overview of the JAX project.

If you're accustomed to writing NumPy code and are starting to explore JAX, you might find the following resources helpful:

@@ -15,12 +15,12 @@ If you're accustomed to writing NumPy code and are starting to explore JAX, you

Tutorials
---------
-If you're ready to explore JAX more deeply, the JAX 101 tutorial goes into much more detail:
+If you're ready to explore JAX more deeply, the JAX tutorials go into much more detail:

.. toctree::
   :maxdepth: 2

-   jax-101/index
+   tutorials

If you prefer a video introduction, here is one from JAX contributor Jake VanderPlas:

docs/building_on_jax.md (1 addition, 1 deletion)
@@ -43,7 +43,7 @@ Here are more specific examples of each pattern.

### Direct Usage
JAX can be directly imported and utilized to build models “from scratch” as shown across this website,
-for example in [JAX 101](https://jax.readthedocs.io/en/latest/jax-101/index.html)
+for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html)
or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html).
This may be the best option if you are unable to find prebuilt code
for your particular challenge, or if you're looking to reduce the number
docs/conf.py (20 additions, 7 deletions)
@@ -72,7 +72,8 @@
"sphinx_remove_toctrees",
'sphinx_copybutton',
'jax_extensions',
'sphinx_design'
'sphinx_design',
'sphinxext.rediraffe',
]

intersphinx_mapping = {
@@ -124,9 +125,8 @@
    'pallas/quickstart.md',
    'pallas/tpu/pipelining.md',
    'jep/9407-type-promotion.md',
-    'jax-101/*.md',
    'autodidax.md',
-    'tutorials/sharded-computation.md',
+    'sharded-computation.md',
]

# The name of the Pygments (syntax highlighting) style to use.
@@ -198,23 +198,20 @@
# List of patterns, relative to source directory, that match notebook
# files that will not be executed.
nb_execution_excludepatterns = [
-    # Includes GPU timings that shouldn't be executed by doc build
-    'notebooks/quickstart.*',
    # Slow notebook: long time to load tf.ds
    'notebooks/neural_network_with_tfds_data.*',
    # Slow notebook
    'notebooks/Neural_Network_and_Data_Loading.*',
    # Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
    'jep/9407-type-promotion.*',
    # TODO(jakevdp): enable execution on the following if possible:
-    'jax-101/*',
    'notebooks/xmap_tutorial.*',
    'notebooks/Distributed_arrays_and_automatic_parallelization.*',
    'notebooks/autodiff_remat.*',
    # Requires accelerators
    'pallas/quickstart.*',
    'pallas/tpu/pipelining.*',
-    'tutorials/sharded-computation.*'
+    'sharded-computation.*'
]

# -- Options for HTMLHelp output ---------------------------------------------
@@ -308,3 +305,19 @@

# Remove auto-generated API docs from sidebars. They take too long to build.
remove_from_toctrees = ["_autosummary/*"]


# Generate redirects from deleted files to new sources
rediraffe_redirects = {
    'notebooks/quickstart.md': 'quickstart.md',
    'jax-101/01-jax-basics.md': 'key-concepts.md',
    'jax-101/02-jitting.md': 'jit-compilation.md',
    'jax-101/03-vectorization.md': 'automatic-vectorization.md',
    'jax-101/04-advanced-autodiff.md': 'automatic-differentiation.md',
    'jax-101/05-random-numbers.md': 'random-numbers.md',
    'jax-101/05.1-pytrees.md': 'working-with-pytrees.md',
    'jax-101/06-parallelism.md': 'sharded-computation.md',
    'jax-101/07-state.md': 'stateful-computations.md',
    'jax-101/08-pjit.rst': 'sharded-computation.md',
    'jax-101/index.rst': 'tutorials.rst',
}
docs/tutorials/debugging.md → docs/debugging.md (1 addition, 1 deletion)
@@ -130,7 +130,7 @@ def f(x):
f(2.) # ==> Pauses during execution
```

-![JAX debugger](../_static/debugger.gif)
+![JAX debugger](_static/debugger.gif)

For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:

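
A sketch of that pattern (using the public `jax.debug.breakpoint` and `jax.lax.cond` APIs; the wrapper function is illustrative):

```python
import jax
import jax.numpy as jnp

def check(x):
    def true_fn(x):
        jax.debug.breakpoint()  # pause only when the predicate holds

    def false_fn(x):
        pass

    jax.lax.cond(x < 0, true_fn, false_fn, x)
    return jnp.abs(x)

jax.jit(check)(jnp.float32(-1.0))  # ==> pauses, since x < 0
```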
docs/developer.md (1 addition, 1 deletion)
@@ -418,7 +418,7 @@ notebooks; for example:

```
pip install jupytext==1.16.0
-jupytext --sync docs/notebooks/quickstart.ipynb
+jupytext --sync docs/notebooks/thinking_in_jax.ipynb
```

The jupytext version should match that specified in
docs/index.rst (2 additions, 3 deletions)
@@ -61,16 +61,15 @@ For an end-to-end transformer library built on JAX, see MaxText_.
   :caption: Getting Started

   installation
-   notebooks/quickstart
   notebooks/thinking_in_jax
+   quickstart
   notebooks/Common_Gotchas_in_JAX
   faq

.. toctree::
   :hidden:
   :maxdepth: 1

-   jax-101/index
+   tutorials


.. toctree::