176 changes: 164 additions & 12 deletions README.md
@@ -1,27 +1,179 @@
<div align="center">
# Kraken

[**🎯 Features**](#-features) | [**🚀 Getting Started**](#-getting-started) | [**💻 Usage**](#-usage) | [**Benchmarks**](#-benchmarks) | [**🤝 Contributing**](#-contributing) | [**⚖️ License**](#️-license)

#### A Triton library of Symmetric Memory operators and examples.

</div>
This repository aims to be a cookbook for developing distributed AI models using Triton and PyTorch's symmetric memory capabilities.

This is NOT intended to be a "framework" or "library" - it is intended to provide some high-performance Triton implementations with in-kernel communication for developers to hack on :) Please copy-paste and fork as you desire.


This repository aims to simplify the process of developing distributed AI models using Triton and PyTorch's symmetric memory capabilities. Our initial kernels are adapted from the [Symmetric Memory Recipes](https://github.com/yifuwang/symm-mem-recipes) by Yifu Wang.
In addition to that, it includes a set of benchmarks to help researchers and developers explore and evaluate their implementations.

## Examples
TBD
Our initial kernels are adapted from the [Symmetric Memory Recipes](https://github.com/yifuwang/symm-mem-recipes) by Yifu Wang.

## Requirements
Kraken requires:
* Triton >= 3.3.0
* PyTorch >= 2.6.0
* Python >= 3.10
## 🎯 Features
- Recipes for high-performance Triton implementations of `all_gather`, `all_reduce`, and `reduce_scatter`.
- Comm-comp fused kernels such as `gemm_one_shot_all_reduce_fused` for increased efficiency.
- A suite of benchmarks to measure and compare the performance of different comm + comp implementations.
- PTX utilities for synchronization primitives not yet supported by Triton.

## Installation
## 🚀 Getting Started
### Prerequisites
- PyTorch (version 2.6.0 or higher)
- Triton (version 3.3.0)
- Python (version 3.10 or higher)
- CUDA (version 12.4 or higher). The version must match your PyTorch installation; a quick environment check is sketched below.
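
To confirm these prerequisites, a short version check (a minimal sketch that uses only standard PyTorch/Triton introspection, nothing Kraken-specific) might look like:

```python
import torch
import triton

# Compare the reported versions against the prerequisites listed above.
print("PyTorch:", torch.__version__)                # expect >= 2.6.0
print("Triton:", triton.__version__)                # expect >= 3.3.0
print("CUDA (PyTorch build):", torch.version.cuda)  # expect >= 12.4
print("GPU available:", torch.cuda.is_available())
```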

### Installation
```bash
git clone https://github.com/meta-pytorch/kraken
cd kraken
pip install -e . -r requirements.txt
```

## License
Source code is made available under a [BSD 3 license](./LICENSE), however you may have other legal obligations that govern your use of other content linked in this repository.
## 💻 Usage
Rather than a rigid framework, Kraken is a hands-on tutorial: developers can embed its techniques into xformers, FlashAttention, TorchInductor-generated kernels—or any custom Triton code.

There are two ways to use Kraken kernels.

First, you can import and use the Kraken kernels directly in your own PyTorch projects. Here is an example of how to use the `one_shot_all_reduce` kernel:

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import kraken
import os

# setup distributed process group.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group("nccl")

# Create and initialize a symmetric memory tensor
# See blog: https://dev-discuss.pytorch.org/t/pytorch-symmetricmemory-harnessing-nvlink-programmability-with-ease/279 for symmetric memory details.
a_shared = symm_mem.empty(
    (4096, 4096),
    dtype=torch.bfloat16,
    device=f"cuda:{local_rank}",
)
symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
a_shared = a_shared.normal_()

# Call one_shot_all_reduce kernel from kraken.
a = kraken.comm.one_shot_all_reduce(a_shared)
```
Remember to run with torchrun! Example torchrun command:
```shell
torchrun --nnodes 1 --nproc-per-node <world_size> \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python \
    python3 example.py
```
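
If you want to sanity-check the result, one option (a minimal sketch reusing the setup from the example above, and assuming the default sum reduction) is to compare against a plain NCCL all-reduce:

```python
# Clone the input *before* the Kraken call so the reference sees the same data.
reference = a_shared.clone()

a = kraken.comm.one_shot_all_reduce(a_shared)

dist.all_reduce(reference)  # NCCL sum all-reduce as the reference
torch.testing.assert_close(a, reference, rtol=1e-2, atol=1e-2)  # loose tolerances for bfloat16
```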

Alternatively, you can build your own custom kernels by leveraging Kraken's low-level primitives. This allows you to create highly optimized kernels tailored to your specific needs. We provide PTX implementations of low-level primitives in `kraken._ptx_utils`.

Here's an example of how to use `kraken._ptx_utils.symm_mem_sync` to synchronize blocks with matching `block_id` across participating devices in a custom kernel. This is often necessary before and after accessing symmetric memory tensors.

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import triton
import triton.language as tl

import kraken
import os

@triton.jit
def custom_distributed_kernel(
    a_shared_ptrs,
    a_signal_pad_ptrs,
    rank: tl.constexpr,
    world_size: tl.constexpr,
):
    # Synchronizes blocks with matching block_id across participating devices.
    # Ensures that all writes to a_shared from previous kernels across all devices
    # are visible to the current kernel:
    kraken._ptx_utils.symm_mem_sync(
        a_signal_pad_ptrs,
        None,
        rank,
        world_size,
        hasPreviousMemAccess=False,
        hasSubsequentMemAccess=True,
    )
    ...  # access a_shared via a_shared_ptrs.

# Create and initialize a symmetric memory tensor
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group("nccl")
a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}")
symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

# Define the grid for kernel launch. For simplicity, we use a single thread block.
grid = (1,)

# Call custom kernel
custom_distributed_kernel[grid](
    symm_mem_hdl.buffer_ptrs_dev,
    symm_mem_hdl.signal_pad_ptrs_dev,
    rank=symm_mem_hdl.rank,
    world_size=symm_mem_hdl.world_size,
)
```
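
As noted above, a barrier is typically also needed after touching symmetric memory, before the kernel exits, so that peers do not reuse the buffer while remote accesses are still in flight. A sketch of that closing call (same arguments as in the kernel above, with the two memory-access flags flipped; this is an assumption based on the signature shown, not a prescribed pattern):

```python
    # ... inside the kernel body, after all accesses to a_shared:
    kraken._ptx_utils.symm_mem_sync(
        a_signal_pad_ptrs,
        None,
        rank,
        world_size,
        hasPreviousMemAccess=True,     # symmetric memory was just accessed
        hasSubsequentMemAccess=False,  # nothing follows this barrier in the kernel
    )
```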


## 📁 Structure
Kraken is organized for easy hacking on distributed Triton kernels:

### Example Kernels
#### `kraken.comm`
Contains communication kernels with fine-grained synchronizations.
- `all_gather_w_progress`
- `one_shot_all_reduce`
- (coming soon) `two_shot_all_reduce`
- (coming soon) `multimem_all_reduce`
#### `kraken.fused`
Fused communication/computation kernels; a call-pattern sketch follows the list below.
- All gather matmul: `all_gather_matmul`
- Gemm all reduce: `gemm_one_shot_all_reduce_fused`
- Gemm reduce scatter: `gemm_reduce_scatter`, `gemm_reduce_scatter_ce_persistent`
- Reduce bias: `one_shot_all_reduce_bias`, `two_shot_all_reduce_bias`
- Reduce bias rms_norm: `one_shot_all_reduce_bias_rms_norm`, `two_shot_all_reduce_bias_rms_norm`
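
These fused all-reduce kernels write into a caller-provided output tensor. A sketch of the call pattern, following the wrappers in `benchmark/benchmark_all_reduce_bias.py` from this PR (the wrapper name here is illustrative; the symmetric-memory staging buffer comes first, as in the benchmark code):

```python
import torch
import kraken

def all_reduce_bias_example(
    x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
) -> torch.Tensor:
    # Allocate the output and let the fused kernel fill it in place.
    y = torch.empty_like(x)
    kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
    return y
```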

#### `kraken.quantized`
(coming soon) Fused communication/computation kernels with quantization.


### Inline PTX Utils
`kraken._ptx_utils` provides inline PTX implementations of memory-barrier synchronizations that are not natively supported by Triton.



### Benchmarks
Kraken includes a set of benchmarks in `benchmark/` to evaluate the performance of its kernels. You can run them as follows:

```bash
torchrun --nnodes 1 --nproc-per-node <world_size> \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
    benchmark/benchmark_all_reduce.py
# ... and so on for other benchmarks
```

Run with `--help` to see the configurable benchmark arguments (backend, dtype, shape, etc.) to profile.
```bash
python benchmark/benchmark_all_reduce.py --help
```


## 🤝 Contributing
Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more details on how to contribute to the project.

## ⚖️ License
Source code is made available under a [BSD 3 license](./LICENSE), however you may have other legal obligations that govern your use of other content linked in this repository.
42 changes: 11 additions & 31 deletions benchmark/benchmark_all_gather_matmul.py
@@ -3,7 +3,6 @@
import csv
from dataclasses import asdict, dataclass
import functools
import itertools
import os
import sys

@@ -63,15 +62,10 @@ def asdict(self):

def generate_experiment_configs(
dtype: torch.dtype,
M: list[int],
N: list[int],
K: list[int],
shapes: list[tuple[int, int, int]],
backends: list[str],
device: torch.device,
) -> list[ExperimentConfig]:
# Generate cross config shapes from M, N, K lists
shapes = list(itertools.product(M, N, K))

all_configs = []
for shape in shapes:
all_configs.append(
@@ -93,7 +87,7 @@ def get_single_backend_fn(backend: str):
if backend == "torch_symm_mem":
return torch_symm_mem_ag_mm
if backend == "triton":
return kraken.all_gather.all_gather_matmul
return kraken.fused.all_gather_matmul
raise NotImplementedError(backend)


@@ -176,9 +170,7 @@ def main(args):
torch.manual_seed(42 + local_rank)

results = []
configs = generate_experiment_configs(
args.dtype, args.M, args.N, args.K, args.backend, device
)
configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
for config in configs:
results.append(
Experiment(
@@ -196,7 +188,7 @@ def shape_input_type(s):
M, N, K = map(int, s.split(","))
return M, N, K
except Exception as e:
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
raise argparse.ArgumentTypeError("Shape must be M, N, K") from e


if __name__ == "__main__":
@@ -228,27 +220,15 @@ def shape_input_type(s):
)

parser.add_argument(
"-M",
type=shape_input_type,
nargs="+",
default=[2**x for x in range(7, 11)],
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
)

parser.add_argument(
"-N",
"--shape",
type=shape_input_type,
nargs="+",
default=[6656],
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
)

parser.add_argument(
"-K",
type=shape_input_type,
nargs="+",
default=[2**x for x in range(12, 15)],
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
default=[
(m, 6656, k)
for m in [2**x for x in range(7, 11)]
for k in [2**x for x in range(12, 16)]
],
help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
)

parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
2 changes: 1 addition & 1 deletion benchmark/benchmark_all_reduce.py
@@ -114,7 +114,7 @@ def get_single_backend_fn(backend: str):
if backend == "dist_2shot":
return symm_mem_two_shot_all_reduce
if backend == "triton_1shot":
return kraken.all_reduce.one_shot_all_reduce
return kraken.comm.one_shot_all_reduce
if backend == "nccl":
return nccl_ring
raise NotImplementedError(backend)
9 changes: 3 additions & 6 deletions benchmark/benchmark_all_reduce_bias.py
@@ -7,26 +7,23 @@
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import kraken
from kraken import _logging as log
from kraken.all_reduce_fusion import (
one_shot_all_reduce_bias,
two_shot_all_reduce_bias,
)


def one_shot_all_reduce_bias(
x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
) -> torch.Tensor:
y = torch.empty_like(x)
one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return y


def two_shot_all_reduce_bias(
x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
) -> torch.Tensor:
y = torch.empty_like(x)
two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return y


26 changes: 12 additions & 14 deletions benchmark/benchmark_all_reduce_bias_rms_norm.py
@@ -7,44 +7,42 @@
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import kraken
from kraken import _logging as log
from kraken.all_reduce_fusion import (
rms_norm,
one_shot_all_reduce_bias,
one_shot_all_reduce_bias_rms_norm,
two_shot_all_reduce_bias,
two_shot_all_reduce_bias_rms_norm,
)


def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
kraken.fused.one_shot_all_reduce_bias_rms_norm(
symm_mem_input, x, bias, rms_weight, y
)
return y


def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return rms_norm(y, rms_weight)
kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return kraken.fused.rms_norm(y, rms_weight)


def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
kraken.fused.two_shot_all_reduce_bias_rms_norm(
symm_mem_input, x, bias, rms_weight, y
)
return y


def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return rms_norm(y, rms_weight)
kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return kraken.fused.rms_norm(y, rms_weight)


def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight):
dist.all_reduce(x)
y = x + bias
return rms_norm(y, rms_weight)
return kraken.fused.rms_norm(y, rms_weight)


def create_benchmarks(b, t, d_size, device, dtype):