[pallas] Added more documentation for grid and BlockSpec.
The starting point was the text in pipelining.md, which I have now replaced with a reference to the separate grid and BlockSpec documentation. Grids and BlockSpecs are also documented in quickstart.md, which I mostly left alone because it is good enough for a simple example. I have also added a few docstrings.
Showing 14 changed files with 320 additions and 137 deletions.
@@ -0,0 +1,23 @@
``jax.experimental.pallas`` module
==================================

.. automodule:: jax.experimental.pallas

Classes
-------

.. autosummary::
   :toctree: _autosummary

   BlockSpec

Functions
---------

.. autosummary::
   :toctree: _autosummary

   pallas_call
   program_id
   num_programs
@@ -0,0 +1,202 @@
(pallas_grids_and_blockspecs)=

# Grids and BlockSpecs

(pallas_grid)=
### `grid`, a.k.a. kernels in a loop

When using {func}`jax.experimental.pallas.pallas_call`, the kernel function
is executed multiple times on different inputs, as specified via the `grid`
argument to `pallas_call`. Conceptually:
```python
pl.pallas_call(some_kernel, grid=(n,))(...)
```
maps to
```python
for i in range(n):
  # do HBM -> VMEM copies
  some_kernel(...)
  # do VMEM -> HBM copies
```
Grids can be generalized to be multi-dimensional, corresponding to nested
loops. For example,

```python
pl.pallas_call(some_kernel, grid=(n, m))(...)
```
is equivalent to
```python
for i in range(n):
  for j in range(m):
    # do HBM -> VMEM copies
    some_kernel(...)
    # do VMEM -> HBM copies
```
This generalizes to any tuple of integers (a length `d` grid corresponds
to `d` nested loops).
The kernel is executed as many times as `prod(grid)`; each of these
invocations is referred to as a "program".
To find out which program (i.e., which element of the grid) the kernel is
currently executing, use {func}`jax.experimental.pallas.program_id`.
For example, for invocation `(1, 2)`, `program_id(axis=0)` returns `1` and
`program_id(axis=1)` returns `2`.
You can also use {func}`jax.experimental.pallas.num_programs` to get the
grid size for a given axis.
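
The nested-loop picture can be sketched in plain Python: the invocations enumerate the Cartesian product of the grid ranges, `prod(grid)` in total. This is a standalone sketch of the iteration order, not Pallas code:

```python
from itertools import product
from math import prod

grid = (2, 3)  # a 2-dimensional grid, as in grid=(n, m) above

# Invocations in nested-loop order: the last grid axis varies fastest.
invocations = list(product(*(range(axis_size) for axis_size in grid)))
print(invocations)

# The kernel runs once per invocation, prod(grid) times in total.
assert len(invocations) == prod(grid)
```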

Here's an example kernel that uses a `grid` and `program_id`.

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from jax.experimental import pallas as pl

>>> def iota_kernel(o_ref):
...   i = pl.program_id(0)
...   o_ref[i] = i

```

We now execute it using `pallas_call` with an additional `grid` argument.

```python
>>> def iota(len: int):
...   return pl.pallas_call(iota_kernel,
...                         out_shape=jax.ShapeDtypeStruct((len,), jnp.int32),
...                         grid=(len,), interpret=True)()
>>> iota(8)
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

```

On GPUs, each program is executed in parallel on separate threads.
Thus, we need to think about race conditions on writes to HBM.
A reasonable approach is to write our kernels in such a way that different
programs write to disjoint places in HBM, avoiding these parallel writes.

On TPUs, programs are executed in a combination of parallel and sequential
fashion (depending on the architecture), so there are slightly different
considerations. See
[the Pallas TPU documentation](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#noteworthy-properties-and-restrictions).
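
The disjoint-writes idea can be illustrated with a plain-Python sketch (the block partition below is hypothetical, not Pallas code): if each program owns its own slice of the output, no index is written by two programs, so parallel execution cannot race on those writes:

```python
num_programs, block_size = 4, 8  # illustrative values

# Program i writes only to out[i * block_size : (i + 1) * block_size].
owned = [range(i * block_size, (i + 1) * block_size)
         for i in range(num_programs)]

# The owned ranges partition the output: no index appears twice, so
# concurrent programs never write the same location.
all_indices = [idx for r in owned for idx in r]
assert len(all_indices) == len(set(all_indices)) == num_programs * block_size
```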

(pallas_blockspec)=

### `BlockSpec`, a.k.a. how to chunk up inputs

In conjunction with the `grid` argument, we need to provide Pallas
the information on how to slice up the input for each grid element.
Specifically, we need to provide a mapping from *the iteration of the loop*
to *the block of our inputs and outputs to be operated on*.
This is provided via {class}`jax.experimental.pallas.BlockSpec` objects.

Before we get into the details of `BlockSpec`s, you may want
to revisit the
[Pallas Quickstart BlockSpecs example](https://jax.readthedocs.io/en/latest/pallas/quickstart.html#block-specs-by-example).

`BlockSpec`s are provided to `pallas_call` via the
`in_specs` and `out_specs` arguments, one for each input and output
respectively.

Informally, the `index_map` of the `BlockSpec` takes as arguments
the invocation indices (as many as the length of the `grid` tuple)
and returns **block indices** (one block index for each axis of
the overall array). Each block index is then multiplied by the
corresponding axis size from `block_shape`
to get the actual element index on the corresponding array axis.
If the resulting element index is such that the block would be
out of bounds, the element index is reduced so that the entire block
fits in bounds.

More precisely, the slices for each axis of an input `x` of
shape `x_shape` are computed as in the function `slices_for_invocation`
below:

```python
>>> def slices_for_invocation(x_shape: tuple[int, ...],
...                           x_spec: pl.BlockSpec,
...                           grid: tuple[int, ...],
...                           invocation_indices: tuple[int, ...]) -> list[slice]:
...   assert len(invocation_indices) == len(grid)
...   assert all(0 <= i < grid_size for i, grid_size in zip(invocation_indices, grid))
...   block_indices = x_spec.index_map(*invocation_indices)
...   assert len(x_shape) == len(x_spec.block_shape) == len(block_indices)
...   elem_indices = []
...   for x_size, block_size, block_idx in zip(x_shape, x_spec.block_shape, block_indices):
...     assert block_size <= x_size  # Blocks must not be larger than the array
...     in_bounds_start_idx = min(block_idx * block_size, x_size - block_size)
...     elem_indices.append(slice(in_bounds_start_idx, in_bounds_start_idx + block_size))
...   return elem_indices

```

For example:
```python
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
...                       grid=(10, 5),
...                       invocation_indices=(2, 3))
[slice(20, 30, None), slice(60, 80, None)]

>>> # This example shows the out-of-bounds adjustment on axis=1
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j: (i, j), (10, 20)),
...                       grid=(10, 10),
...                       invocation_indices=(2, 8))
[slice(20, 30, None), slice(80, 100, None)]

>>> # Same array and block shapes, but the index_map ignores the third grid
>>> # axis, so each block is visited 4 times
>>> slices_for_invocation(x_shape=(100, 100),
...                       x_spec=pl.BlockSpec(lambda i, j, k: (i, j), (10, 20)),
...                       grid=(10, 5, 4),
...                       invocation_indices=(2, 3, 0))
[slice(20, 30, None), slice(60, 80, None)]

```

The function `show_invocations` defined below uses Pallas to show the
invocation indices. The `iota_2D_kernel` fills each output block with a
two-digit decimal number: the first digit is the invocation index over the
first grid axis and the second digit is the invocation index over the
second grid axis:

```python
>>> def show_invocations(x_shape, block_shape, grid):
...   def iota_2D_kernel(o_ref):
...     axes = 10 * pl.program_id(0) + pl.program_id(1)
...     o_ref[...] = jnp.full(o_ref.shape, axes)
...   res = pl.pallas_call(iota_2D_kernel,
...                        out_shape=jax.ShapeDtypeStruct(x_shape, dtype=np.int32),
...                        grid=grid,
...                        in_specs=[],
...                        out_specs=pl.BlockSpec(lambda i, j: (i, j), block_shape),
...                        interpret=True)()
...   print(res)

```

For example:
```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(4, 2))
[[ 0  0  0  1  1  1]
 [ 0  0  0  1  1  1]
 [10 10 10 11 11 11]
 [10 10 10 11 11 11]
 [20 20 20 21 21 21]
 [20 20 20 21 21 21]
 [30 30 30 31 31 31]
 [30 30 30 31 31 31]]

```

When multiple invocations write to the same elements of the output
array, the result is platform dependent. On the CPU platform the
invocations are executed in order, so you will see the last invocation
that wrote to each block, as in the last 2 rows and last 3 columns of
the output below:

```python
>>> show_invocations(x_shape=(8, 6), block_shape=(2, 3), grid=(10, 10))
[[ 0  0  0  9  9  9]
 [ 0  0  0  9  9  9]
 [10 10 10 19 19 19]
 [10 10 10 19 19 19]
 [20 20 20 29 29 29]
 [20 20 20 29 29 29]
 [90 90 90 99 99 99]
 [90 90 90 99 99 99]]
```
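
The last example can be cross-checked with a plain NumPy simulation of the slicing rule (a sketch of the semantics only, not of how Pallas is implemented): iterate the grid in nested-loop order, clamp each block into bounds, and let later writes overwrite earlier ones:

```python
import numpy as np
from itertools import product

def simulate(x_shape, block_shape, grid):
    out = np.zeros(x_shape, dtype=np.int32)
    # Nested-loop order: on CPU the last invocation to write a block wins.
    for i, j in product(*(range(axis_size) for axis_size in grid)):
        slices = []
        for size, bsize, bidx in zip(x_shape, block_shape, (i, j)):
            start = min(bidx * bsize, size - bsize)  # keep the block in bounds
            slices.append(slice(start, start + bsize))
        out[tuple(slices)] = 10 * i + j  # same values as iota_2D_kernel
    return out

print(simulate((8, 6), (2, 3), (10, 10)))
```

The printed array matches the `show_invocations` output above: out-of-bounds block indices are clamped to the last in-bounds block, and the highest invocation indices win on overlapping writes.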