Skip to content

Commit

Permalink
quickstart tweaks (from jax-ml#20819)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 19, 2024
1 parent 32922f6 commit 94e3a6e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions docs/tutorials/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ kernelspec:

# Quickstart

**JAX a library for array-oriented numerical computation (*a la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.

This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:

Expand Down Expand Up @@ -125,16 +125,18 @@ In the above example we jitted `sum_logistic` and then took its derivative. We c
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```

Similarly, the {func}`jax.jacobian` transformation can be used to compute gradients of vector-valued functions:
Beyond scalar-valued functions, the {func}`jax.jacobian` transformation can be
used to compute the full Jacobian matrix for vector-valued functions:

```{code-cell}
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
```

For more advanced autodiff operations, you can use {func}`jax..jacrev` for reverse-mode vector-Jacobian products,
and {func}`jax.jacfwd` for forward-mode Jacobian-vector products.
For more advanced autodiff operations, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products,
and {func}`jax.jvp` and {func}`jax.linearize` for forward-mode Jacobian-vector products.
The two can be composed arbitrarily with one another, and with other JAX transformations.
For example, {func}`jax.jvp` and {func}`jax.vjp` are used to define the forward-mode {func}`jax.jacfwd` and reverse-mode {func}`jax.jacrev` for computing Jacobians in forward- and reverse-mode, respectively.
Here's one way to compose them to make a function that efficiently computes full Hessian matrices:

```{code-cell}
Expand Down

0 comments on commit 94e3a6e

Please sign in to comment.