[pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, which I have now
replaced with a reference to the separate grid and BlockSpec
documentation.

Grids and BlockSpecs are also documented in quickstart.md, which I
mostly left alone because it was good enough for a simple example.

I have also attempted to add a few docstrings.
gnecula committed Jun 27, 2024
1 parent 945b1c3 commit 7267ba6
Showing 14 changed files with 320 additions and 137 deletions.
23 changes: 23 additions & 0 deletions docs/jax.experimental.pallas.rst
@@ -0,0 +1,23 @@
``jax.experimental.pallas`` module
==================================

.. automodule:: jax.experimental.pallas

Classes
-------

.. autosummary::
   :toctree: _autosummary

   BlockSpec

Functions
---------

.. autosummary::
   :toctree: _autosummary

   pallas_call
   program_id
   num_programs

1 change: 1 addition & 0 deletions docs/jax.experimental.rst
@@ -28,6 +28,7 @@ Experimental Modules
jax.experimental.mesh_utils
jax.experimental.serialize_executable
jax.experimental.shard_map
jax.experimental.pallas

Experimental APIs
-----------------
7 changes: 4 additions & 3 deletions docs/pallas/design.md
@@ -286,9 +286,10 @@ The signature of `pallas_call` is as follows:
```python
def pallas_call(
    kernel: Callable,
    out_shape: Sequence[jax.ShapeDtypeStruct],
    *,
    in_specs: Sequence[Spec],
    out_specs: Sequence[Spec],
    out_shapes: Sequence[jax.ShapeDtypeStruct],
    grid: Optional[Tuple[int, ...]] = None) -> Callable:
  ...
```
@@ -303,9 +304,9 @@ information about how the kernel will be scheduled on the accelerator.
The (rough) semantics for `pallas_call` are as follows:

```python
def pallas_call(kernel, in_specs, out_specs, out_shapes, grid):
def pallas_call(kernel, out_shape, *, in_specs, out_specs, grid):
  def execute(*args):
    outputs = map(empty_ref, out_shapes)
    outputs = map(empty_ref, out_shape)
    grid_indices = map(range, grid)
    for indices in itertools.product(*grid_indices):  # Could run in parallel!
      local_inputs = [in_spec.transform(arg, indices) for arg, in_spec in
202 changes: 202 additions & 0 deletions docs/pallas/grid_blockspec.md
@@ -0,0 +1,202 @@
(pallas_grids_and_blockspecs)=

# Grids and BlockSpecs

(pallas_grid)=
### `grid`, a.k.a. kernels in a loop

When using {func}`jax.experimental.pallas.pallas_call`, the kernel function
is executed multiple times on different inputs, as specified via the `grid` argument
to `pallas_call`. Conceptually:
```python
pl.pallas_call(some_kernel, grid=(n,))(...)
```
maps to
```python
for i in range(n):
  # do HBM -> VMEM copies
  some_kernel(...)
  # do VMEM -> HBM copies
```
Grids can be generalized to be multi-dimensional, corresponding to nested
loops. For example,

```python
pl.pallas_call(some_kernel, grid=(n, m))(...)
```
is equivalent to
```python
for i in range(n):
  for j in range(m):
    # do HBM -> VMEM copies
    some_kernel(...)
    # do VMEM -> HBM copies
```
This generalizes to any tuple of integers (a grid of length `d` corresponds
to `d` nested loops).
The kernel is executed `prod(grid)` times, and each of these invocations is
referred to as a "program".
To find out which program (i.e., which element of the grid) the kernel is currently
executing, use {func}`jax.experimental.pallas.program_id`.
For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
`program_id(axis=1)` returns `2`.
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.

Here's an example kernel that uses a `grid` and `program_id`.

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.experimental import pallas as pl

>>> def iota_kernel(o_ref):
...   i = pl.program_id(0)
...   o_ref[i] = i

```

We now execute it using `pallas_call` with an additional `grid` argument.

```python
>>> def iota(size: int):
...   return pl.pallas_call(iota_kernel,
...                         out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
...                         grid=(size,), interpret=True)()
>>> iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

```
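
The example above does not exercise {func}`jax.experimental.pallas.num_programs`.
As a small sketch (a hypothetical variant, not part of the original example), it can
be combined with `program_id`, for instance to fill the output in reverse order:

```python
>>> def reverse_iota_kernel(o_ref):
...   n = pl.num_programs(0)
...   i = pl.program_id(0)
...   o_ref[i] = n - 1 - i

>>> pl.pallas_call(reverse_iota_kernel,
...                out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
...                grid=(8,), interpret=True)()
Array([7, 6, 5, 4, 3, 2, 1, 0], dtype=int32)

```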

On GPUs, each program is executed in parallel on separate threads.
Thus, we need to think about race conditions on writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint locations in HBM, avoiding parallel writes to the same memory.

On TPUs, programs are executed in a combination of parallel and sequential fashion
(depending on the architecture), so there are slightly different considerations.
See [the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).

(pallas_blockspec)=

### `BlockSpec`, a.k.a. how to chunk up inputs

In conjunction with the `grid` argument, we need to provide Pallas
with information on how to slice up the inputs and outputs for each grid element.
Specifically, we need to provide a mapping from *the iteration of the loop*
to *the block of our inputs and outputs that should be operated on*.
This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.

Before we get into the details of `BlockSpec`s, you may want
to revisit the
[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).

`BlockSpec`s are provided to `pallas_call` via the
`in_specs` and `out_specs` arguments, one for each input and output respectively.
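
For example, here is a minimal sketch (the `double_kernel` and the `double` wrapper below
are hypothetical, not part of the Pallas API) that doubles its input block by block. The
same `BlockSpec` is reused for both `in_specs` and `out_specs` because the input and the
output are carved up identically:

```python
>>> def double_kernel(x_ref, o_ref):
...   o_ref[...] = x_ref[...] * 2

>>> def double(x):
...   block_spec = pl.BlockSpec(lambda i: (i, 0), (2, 128))  # 4 blocks of shape (2, 128)
...   return pl.pallas_call(double_kernel,
...                         out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
...                         grid=(4,),
...                         in_specs=[block_spec],  # one BlockSpec per input
...                         out_specs=block_spec,   # a single BlockSpec for the single output
...                         interpret=True)(x)

>>> x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128))
>>> (double(x) == x * 2).all()
Array(True, dtype=bool)

```

Note that `in_specs` is a sequence with one entry per input, while `out_specs` here is a
single `BlockSpec` matching the single output described by `out_shape`.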

Informally, the `index_map` of the `BlockSpec` takes as arguments
the invocation indices (as many as the length of the `grid` tuple),
and returns **block indices** (one block index for each axis of
the overall array). Each block index is then multiplied by the
corresponding block size from `block_shape`
to get the starting element index on the corresponding array axis.
If the resulting block would extend out of bounds, the starting
element index is reduced so that the entire block fits in bounds.

More precisely, the slices for each axis of the input `x` of
shape `x_shape` are computed as in the function `slices_for_invocation`
below:

```python
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> list[slice]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     assert block_size <= x_size  # Blocks must not be larger than the array
...     in_bounds_start_idx = min(block_idx * block_size, x_size - block_size)
...     elem_indices.append(slice(in_bounds_start_idx, in_bounds_start_idx + block_size))
...   return elem_indices

```

For example:
```python
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
...                       grid=(10, 5),
...                       invocation_indices=(2, 3))
[slice(20, 30, None), slice(60, 80, None)]

>>> # This example shows the out-of-bounds adjustment on axis=1
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
...                       grid=(10, 10),
...                       invocation_indices=(2, 8))
[slice(20, 30, None), slice(80, 100, None)]

>>> # Same array and block shapes as above, but the grid iterates over each block 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
...                       grid=(10, 5, 4),
...                       invocation_indices=(2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]

```

The function `show_invocations` defined below uses Pallas to show the
invocation indices. The `iota_2D_kernel` fills each output block
with a two-digit decimal number whose tens digit is the invocation
index along the first axis and whose ones digit is the invocation
index along the second axis:

```python
>>> def show_invocations(x_shape, block_shape, grid):
...   def iota_2D_kernel(o_ref):
...     axes = 10 * pl.program_id(0) + pl.program_id(1)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(iota_2D_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(lambda i, j: (i, j), block_shape),
...                        interpret=True)()
...   print(res)

```

For example:
```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

```

When multiple invocations write to the same elements of the output
array, the result is platform dependent. On the CPU platform the
invocations are executed in order, so you will see the value written by the
last invocation for each block, as in the last 2 rows and last 3 columns of the output below:

```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(10, 10))
[[ 0  0  0  9  9  9]
 [ 0  0  0  9  9  9]
 [10 10 10 19 19 19]
 [10 10 10 19 19 19]
 [20 20 20 29 29 29]
 [20 20 20 29 29 29]
 [90 90 90 99 99 99]
 [90 90 90 99 99 99]]
```
4 changes: 3 additions & 1 deletion docs/pallas/index.rst
@@ -4,13 +4,15 @@ Pallas: a JAX kernel language
=============================
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
This section contains tutorials, guides and examples for using Pallas.
See also the :mod:`jax.experimental.pallas` module API documentation.

.. toctree::
:caption: Guides
:maxdepth: 2

design
quickstart
design
grid_blockspec

.. toctree::
:caption: Platform Features
12 changes: 8 additions & 4 deletions docs/pallas/quickstart.ipynb
@@ -72,7 +72,7 @@
"\n",
"Let's dissect this function a bit. Unlike most JAX functions you've probably written,\n",
"it does not take in `jax.Array`s as inputs and doesn't return any values.\n",
"Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs\n",
"but we are given an `o_ref`, which corresponds to the desired output.\n",
"\n",
"**Reading from `Ref`s**\n",
@@ -194,7 +194,7 @@
"live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations\n",
"that operate on \"blocks\" of those arrays that can fit in SRAM.\n",
"\n",
"### Grids\n",
"### Grids by example\n",
"\n",
"To automatically \"carve\" up the inputs and outputs, you provide a `grid` and\n",
"`BlockSpec`s to `pallas_call`.\n",
@@ -279,15 +279,17 @@
"operations like matrix multiplications really quickly.\n",
"\n",
"On TPUs, programs are executed in a combination of parallel and sequential\n",
"(depending on the architecture) so there are slightly different considerations."
"(depending on the architecture) so there are slightly different considerations.\n",
"\n",
"You can read more details at {ref}`pallas_grid`."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Block specs"
"### Block specs by example"
]
},
{
@@ -385,6 +387,8 @@
"\n",
"These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.\n",
"\n",
"For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.\n",
"\n",
"Underneath the hood, `pallas_call` will automatically carve up your inputs and\n",
"outputs into `Ref`s for each block that will be passed into the kernel."
]
10 changes: 7 additions & 3 deletions docs/pallas/quickstart.md
@@ -53,7 +53,7 @@ def add_vectors_kernel(x_ref, y_ref, o_ref):

Let's dissect this function a bit. Unlike most JAX functions you've probably written,
it does not take in `jax.Array`s as inputs and doesn't return any values.
Instead it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
Instead, it takes in *`Ref`* objects as inputs. Note that we also don't have any outputs
but we are given an `o_ref`, which corresponds to the desired output.

**Reading from `Ref`s**
@@ -133,7 +133,7 @@ Part of writing Pallas kernels is thinking about how to take big arrays that
live in high-bandwidth memory (HBM, also known as DRAM) and expressing computations
that operate on "blocks" of those arrays that can fit in SRAM.

### Grids
### Grids by example

To automatically "carve" up the inputs and outputs, you provide a `grid` and
`BlockSpec`s to `pallas_call`.
@@ -187,9 +187,11 @@ operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.

You can read more details at {ref}`pallas_grid`.

+++

### Block specs
### Block specs by example

+++

@@ -279,6 +281,8 @@ Finally, for `z` we use `BlockSpec(lambda i, j: (i, j), (512, 512))`.

These `BlockSpec`s are passed into `pallas_call` via `in_specs` and `out_specs`.

For more detail on `BlockSpec`s see {ref}`pallas_blockspec`.

Underneath the hood, `pallas_call` will automatically carve up your inputs and
outputs into `Ref`s for each block that will be passed into the kernel.

3 changes: 2 additions & 1 deletion docs/pallas/tpu/details.rst
@@ -65,7 +65,8 @@ Noteworthy properties and restrictions
``BlockSpec``\s and grid iteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``BlockSpec``\s generally behave as expected in Pallas --- every invocation of
``BlockSpec``\s (see :ref:`pallas_blockspec`) generally behave as expected
in Pallas --- every invocation of
the kernel body gets access to slices of the inputs and is meant to initialize a slice
of the output.
