diff --git a/README.md b/README.md
index 45f4a4f..f431b2e 100644
--- a/README.md
+++ b/README.md
@@ -1,27 +1,179 @@
-
# Kraken
+[**🎯 Features**](#-features) | [**🚀 Getting Started**](#-getting-started) | [**💻 Usage**](#-usage) | [**Benchmarks**](#benchmarks) | [**🤝 Contributing**](#-contributing) | [**⚖️ License**](#️-license)
+
#### A Triton library of Symmetric Memory operators and examples.
+This repository aims to be a cookbook for developing distributed AI models using Triton and PyTorch's symmetric memory capabilities.
+
+This is NOT intended to be a "framework" or "library"; rather, it provides high-performance Triton implementations with in-kernel communication for developers to hack on :) Please copy-paste and fork as you desire.
+
-This repository aims to simplify the process of developing distributed AI models using Triton and PyTorch's symmetric memory capabilities. Our initial kernels are adapted from the [Symmetric Memory Recipes](https://github.com/yifuwang/symm-mem-recipes) by Yifu Wang.
+In addition, it includes a set of benchmarks to help researchers and developers explore and evaluate their implementations.
-## Examples
-TBD
+Our initial kernels are adapted from the [Symmetric Memory Recipes](https://github.com/yifuwang/symm-mem-recipes) by Yifu Wang.
-## Requirements
-Kraken requires:
-* Triton >= 3.3.0
-* PyTorch >= 2.6.0
-* Python >= 3.10
+## 🎯 Features
+- Recipes for high-performance Triton implementations of `all_gather`, `all_reduce`, and `reduce_scatter`.
+- Fused communication-computation kernels such as `gemm_one_shot_all_reduce_fused` for increased efficiency.
+- A suite of benchmarks to measure and compare the performance of different comm + comp implementations.
+- PTX utilities for synchronization primitives not yet supported by Triton.
-## Installation
+## 🚀 Getting Started
+### Prerequisites
+- PyTorch (version 2.6.0 or higher)
+- Triton (version 3.3.0)
+- Python (version 3.10 or higher)
+- CUDA (version 12.4 or higher); the version must match your PyTorch installation.
+
+### Installation
```bash
git clone https://github.com/meta-pytorch/kraken
cd kraken
pip install -e . -r requirements.txt
```
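+
+To quickly verify the setup (a minimal sanity check, assuming the prerequisites above are installed):
+```python
+import torch
+import kraken
+
+print(torch.__version__, torch.cuda.is_available())
+print(kraken.__all__)  # ['_logging', '_ptx_utils', 'comm', 'fused']
+```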
-## License
-Source code is made available under a [BSD 3 license](./LICENSE), however you may have other legal obligations that govern your use of other content linked in this repository.
+## 💻 Usage
+Rather than a rigid framework, Kraken is a hands-on tutorial: developers can embed its techniques into xformers, FlashAttention, TorchInductor-generated kernels, or any custom Triton code.
+
+There are two ways of using Kraken kernels: calling the provided kernels directly, or building your own custom kernels on top of Kraken's low-level primitives.
+
+First, you can import and use the Kraken kernels in your own PyTorch projects. Here is an example of how to use the `one_shot_all_reduce` kernel:
+
+```python
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+import kraken
+import os
+
+# setup distributed process group.
+local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")
+
+# Create and initialize a symmetric memory tensor
+# See blog: https://dev-discuss.pytorch.org/t/pytorch-symmetricmemory-harnessing-nvlink-programmability-with-ease/279 for symmetric memory details.
+a_shared = symm_mem.empty(
+    (4096, 4096),
+    dtype=torch.bfloat16,
+    device=f"cuda:{local_rank}",
+)
+symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
+a_shared = a_shared.normal_()
+
+# Call one_shot_all_reduce kernel from kraken.
+a = kraken.comm.one_shot_all_reduce(a_shared)
+```
+Remember to run with torchrun! Example torchrun command (set `--nproc-per-node` to your GPU count):
+```shell
+torchrun --nnodes 1 --nproc-per-node 8 \
+  --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python \
+  python3 example.py
+```
+
+Alternatively, you can build your own custom kernels by leveraging Kraken's low-level primitives. This allows you to create highly optimized kernels tailored to your specific needs. We provide PTX implementations of low-level primitives in `kraken._ptx_utils`.
+
+Here's an example of how to use `kraken._ptx_utils.symm_mem_sync` to synchronize blocks with matching `block_id` across participating devices in a custom kernel. This is often necessary before and after accessing symmetric memory tensors.
+
+```python
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+
+import triton
+import triton.language as tl
+
+import kraken
+import os
+
+@triton.jit
+def custom_distributed_kernel(
+ a_shared_ptrs,
+ a_signal_pad_ptrs,
+ rank: tl.constexpr,
+ world_size: tl.constexpr,
+):
+ # Synchronizes blocks with matching block_id across participating devices.
+ # Ensures that all writes to a_shared from previous kernels across all devices
+ # are visible to the current kernel:
+ kraken._ptx_utils.symm_mem_sync(
+ a_signal_pad_ptrs,
+ None,
+ rank,
+ world_size,
+ hasPreviousMemAccess=False,
+ hasSubsequentMemAccess=True,
+ )
+ ... # access a_shared via a_shared_ptrs.
+
+# Create and initialize a symmetric memory tensor
+local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")
+a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}")
+symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
+
+# Define the grid for kernel launch. For simplicity, we use a single thread block.
+grid = (1,)
+
+# Call custom kernel
+custom_distributed_kernel[grid](
+ symm_mem_hdl.buffer_ptrs_dev,
+ symm_mem_hdl.signal_pad_ptrs_dev,
+ rank=symm_mem_hdl.rank,
+ world_size=symm_mem_hdl.world_size,
+)
+```
+
+
+## 📁 Structure
+Kraken is organized for easy hacking of distributed Triton kernels:
+
+### Example Kernels
+#### `kraken.comm`
+Contains communication kernels with fine-grained synchronization.
+- `all_gather_w_progress`
+- `one_shot_all_reduce`
+- (coming soon) `two_shot_all_reduce`
+- (coming soon) `multimem_all_reduce`
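+
+A minimal usage sketch of `all_gather_w_progress` (assumes an initialized process group and device setup, as in the Usage section above):
+```python
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+
+import kraken
+
+# Each rank contributes a (1024, 1024) shard held in symmetric memory.
+a_shared = symm_mem.empty((1024, 1024), dtype=torch.bfloat16, device="cuda").normal_()
+symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
+
+# Gathers every rank's shard into a (world_size * 1024, 1024) tensor using the
+# copy engine, signalling per-chunk completion via an internal progress tensor.
+a_out = kraken.comm.all_gather_w_progress(a_shared)
+```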
+#### `kraken.fused`
+Fused communication/computation kernels.
+- All-gather matmul: `all_gather_matmul`
+- GEMM all-reduce: `gemm_one_shot_all_reduce_fused`
+- GEMM reduce-scatter: `gemm_reduce_scatter`, `gemm_reduce_scatter_ce_persistent`
+- All-reduce + bias: `one_shot_all_reduce_bias`, `two_shot_all_reduce_bias`
+- All-reduce + bias + RMSNorm: `one_shot_all_reduce_bias_rms_norm`, `two_shot_all_reduce_bias_rms_norm`
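+
+A minimal usage sketch of the fused GEMM + one-shot all-reduce (launched with torchrun as in the Usage section; the shapes and dtype below are illustrative):
+```python
+import os
+
+import torch
+import torch.distributed as dist
+
+import kraken
+
+local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")
+
+# Each rank computes its local (M, K) @ (K, N) GEMM; the fused kernel then
+# sums the partial results across all ranks within the same kernel.
+a = torch.randn(256, 512, dtype=torch.float32, device=f"cuda:{local_rank}")
+b = torch.randn(512, 256, dtype=torch.float32, device=f"cuda:{local_rank}")
+c = kraken.fused.gemm_one_shot_all_reduce_fused(a, b)  # (M, N), summed over ranks
+```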
+
+#### `kraken.quantized`
+(coming soon) Fused communication/computation kernels with quantization.
+
+
+### Inline PTX Utils
+`kraken._ptx_utils` provides inline PTX implementations of memory-barrier synchronization primitives that are not natively supported by Triton: `arrive_gmem_barrier`, `wait_gmem_barrier`, `get_flat_tid`, `send_signal`, and `symm_mem_sync` (see the custom kernel example in Usage above).
+
+
+
+### Benchmarks
+Kraken includes a set of benchmarks in `benchmark/` to evaluate the performance of its kernels. You can run them as follows:
+
+```bash
+torchrun --nnodes 1 --nproc-per-node 8 \
+  --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
+  benchmark/benchmark_all_reduce.py
+# ... and so on for other benchmarks
+```
+
+Run with `--help` to see the configurable benchmark arguments, such as the backends, dtype, and shapes to profile.
+```bash
+python benchmark/benchmark_all_reduce.py --help
+```
+
+
+## 🤝 Contributing
+Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more details on how to contribute to the project.
+
+## ⚖️ License
+Source code is made available under a [BSD 3 license](./LICENSE); however, you may have other legal obligations that govern your use of other content linked in this repository.
\ No newline at end of file
diff --git a/benchmark/benchmark_all_gather_matmul.py b/benchmark/benchmark_all_gather_matmul.py
index ccf70ee..b7fb1bd 100644
--- a/benchmark/benchmark_all_gather_matmul.py
+++ b/benchmark/benchmark_all_gather_matmul.py
@@ -3,7 +3,6 @@
import csv
from dataclasses import asdict, dataclass
import functools
-import itertools
import os
import sys
@@ -63,15 +62,10 @@ def asdict(self):
def generate_experiment_configs(
dtype: torch.dtype,
- M: list[int],
- N: list[int],
- K: list[int],
+ shapes: list[tuple[int, int, int]],
backends: list[str],
device: torch.device,
) -> list[ExperimentConfig]:
- # Generate cross config shapes from M, N, K lists
- shapes = list(itertools.product(M, N, K))
-
all_configs = []
for shape in shapes:
all_configs.append(
@@ -93,7 +87,7 @@ def get_single_backend_fn(backend: str):
if backend == "torch_symm_mem":
return torch_symm_mem_ag_mm
if backend == "triton":
- return kraken.all_gather.all_gather_matmul
+ return kraken.fused.all_gather_matmul
raise NotImplementedError(backend)
@@ -176,9 +170,7 @@ def main(args):
torch.manual_seed(42 + local_rank)
results = []
- configs = generate_experiment_configs(
- args.dtype, args.M, args.N, args.K, args.backend, device
- )
+ configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
for config in configs:
results.append(
Experiment(
@@ -196,7 +188,7 @@ def shape_input_type(s):
M, N, K = map(int, s.split(","))
return M, N, K
except Exception as e:
- raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
+ raise argparse.ArgumentTypeError("Shape must be M, N, K") from e
if __name__ == "__main__":
@@ -228,27 +220,15 @@ def shape_input_type(s):
)
parser.add_argument(
- "-M",
- type=shape_input_type,
- nargs="+",
- default=[2**x for x in range(7, 11)],
- help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
- )
-
- parser.add_argument(
- "-N",
+ "--shape",
type=shape_input_type,
nargs="+",
- default=[6656],
- help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
- )
-
- parser.add_argument(
- "-K",
- type=shape_input_type,
- nargs="+",
- default=[2**x for x in range(12, 15)],
- help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
+ default=[
+ (m, 6656, k)
+ for m in [2**x for x in range(7, 11)]
+ for k in [2**x for x in range(12, 16)]
+ ],
+ help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
)
parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
diff --git a/benchmark/benchmark_all_reduce.py b/benchmark/benchmark_all_reduce.py
index c240263..cd91285 100644
--- a/benchmark/benchmark_all_reduce.py
+++ b/benchmark/benchmark_all_reduce.py
@@ -114,7 +114,7 @@ def get_single_backend_fn(backend: str):
if backend == "dist_2shot":
return symm_mem_two_shot_all_reduce
if backend == "triton_1shot":
- return kraken.all_reduce.one_shot_all_reduce
+ return kraken.comm.one_shot_all_reduce
if backend == "nccl":
return nccl_ring
raise NotImplementedError(backend)
diff --git a/benchmark/benchmark_all_reduce_bias.py b/benchmark/benchmark_all_reduce_bias.py
index 88f8a43..fd033f2 100644
--- a/benchmark/benchmark_all_reduce_bias.py
+++ b/benchmark/benchmark_all_reduce_bias.py
@@ -7,18 +7,15 @@
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
+import kraken
from kraken import _logging as log
-from kraken.all_reduce_fusion import (
- one_shot_all_reduce_bias,
- two_shot_all_reduce_bias,
-)
def one_shot_all_reduce_bias(
x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
) -> torch.Tensor:
y = torch.empty_like(x)
- one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+ kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return y
@@ -26,7 +23,7 @@ def two_shot_all_reduce_bias(
x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
) -> torch.Tensor:
y = torch.empty_like(x)
- two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+ kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
return y
diff --git a/benchmark/benchmark_all_reduce_bias_rms_norm.py b/benchmark/benchmark_all_reduce_bias_rms_norm.py
index 93b54c0..c8c5391 100644
--- a/benchmark/benchmark_all_reduce_bias_rms_norm.py
+++ b/benchmark/benchmark_all_reduce_bias_rms_norm.py
@@ -7,44 +7,42 @@
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
+import kraken
from kraken import _logging as log
-from kraken.all_reduce_fusion import (
- rms_norm,
- one_shot_all_reduce_bias,
- one_shot_all_reduce_bias_rms_norm,
- two_shot_all_reduce_bias,
- two_shot_all_reduce_bias_rms_norm,
-)
def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
- one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+ kraken.fused.one_shot_all_reduce_bias_rms_norm(
+ symm_mem_input, x, bias, rms_weight, y
+ )
return y
def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
- one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
- return rms_norm(y, rms_weight)
+ kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+ return kraken.fused.rms_norm(y, rms_weight)
def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
- two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+ kraken.fused.two_shot_all_reduce_bias_rms_norm(
+ symm_mem_input, x, bias, rms_weight, y
+ )
return y
def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
y = torch.empty_like(x)
- two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
- return rms_norm(y, rms_weight)
+ kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+ return kraken.fused.rms_norm(y, rms_weight)
def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight):
dist.all_reduce(x)
y = x + bias
- return rms_norm(y, rms_weight)
+ return kraken.fused.rms_norm(y, rms_weight)
def create_benchmarks(b, t, d_size, device, dtype):
diff --git a/benchmark/benchmark_matmul_reduce_scatter.py b/benchmark/benchmark_matmul_reduce_scatter.py
index e8179da..119824f 100644
--- a/benchmark/benchmark_matmul_reduce_scatter.py
+++ b/benchmark/benchmark_matmul_reduce_scatter.py
@@ -1,16 +1,15 @@
import argparse
+from collections import defaultdict
import csv
+from dataclasses import asdict, dataclass
import functools
-import itertools
import os
import sys
-from collections import defaultdict
-from dataclasses import asdict, dataclass
+from tabulate import tabulate
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
-from tabulate import tabulate
# Add the kraken directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -20,18 +19,22 @@
def torch_symm_mem_gemm_rs(a, b):
- gemm_rs_output = torch.ops.symm_mem.fused_matmul_reduce_scatter(
+ return torch.ops.symm_mem.fused_matmul_reduce_scatter(
a, b, "sum", scatter_dim=0, group_name=dist.group.WORLD.group_name
)
- return gemm_rs_output
+
def nccl_mem_gemm_rs(a, b):
- from torch.distributed._functional_collectives import reduce_scatter_tensor, wait_tensor
+ from torch.distributed._functional_collectives import (
+ reduce_scatter_tensor,
+ wait_tensor,
+ )
gemm_output = torch.matmul(a, b)
- rs_o = reduce_scatter_tensor(gemm_output, "sum", scatter_dim=0, group=dist.group.WORLD)
- gemm_rs_output = wait_tensor(rs_o)
- return gemm_rs_output
+ rs_o = reduce_scatter_tensor(
+ gemm_output, "sum", scatter_dim=0, group=dist.group.WORLD
+ )
+ return wait_tensor(rs_o)
@dataclass(frozen=True)
@@ -64,15 +67,10 @@ def asdict(self):
def generate_experiment_configs(
dtype: torch.dtype,
- M: list[int],
- N: list[int],
- K: list[int],
+ shapes: list[tuple[int, int, int]],
backends: list[str],
device: torch.device,
) -> list[ExperimentConfig]:
- # Generate cross config shapes from M, N, K lists
- shapes = list(itertools.product(M, N, K))
-
all_configs = []
for shape in shapes:
all_configs.append(
@@ -94,7 +92,7 @@ def get_single_backend_fn(backend: str):
if backend == "torch_symm_mem":
return torch_symm_mem_gemm_rs
if backend == "triton":
- return kraken.reduce_scatter_fusion.gemm_reduce_scatter
+ return kraken.fused.gemm_reduce_scatter
raise NotImplementedError(backend)
@@ -177,9 +175,7 @@ def main(args):
torch.manual_seed(42 + local_rank)
results = []
- configs = generate_experiment_configs(
- args.dtype, args.M, args.N, args.K, args.backend, device
- )
+ configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
for config in configs:
results.append(
Experiment(
@@ -197,7 +193,7 @@ def shape_input_type(s):
M, N, K = map(int, s.split(","))
return M, N, K
except Exception as e:
- raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
+ raise argparse.ArgumentTypeError("Shape must be M, N, K") from e
if __name__ == "__main__":
@@ -229,27 +225,15 @@ def shape_input_type(s):
)
parser.add_argument(
- "-M",
+ "--shape",
type=shape_input_type,
nargs="+",
- default=[2**x for x in range(7, 11)],
- help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
- )
-
- parser.add_argument(
- "-N",
- type=shape_input_type,
- nargs="+",
- default=[6656],
- help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
- )
-
- parser.add_argument(
- "-K",
- type=shape_input_type,
- nargs="+",
- default=[2**x for x in range(12, 16)],
- help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
+ default=[
+ (m, 6656, k)
+ for m in [2**x for x in range(7, 11)]
+ for k in [2**x for x in range(12, 16)]
+ ],
+ help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
)
parser.add_argument("-dtype", type=str, help="dtype", default="float32")
diff --git a/kraken/__init__.py b/kraken/__init__.py
index dffc887..ec4d3e4 100644
--- a/kraken/__init__.py
+++ b/kraken/__init__.py
@@ -1,3 +1,13 @@
-from . import _logging, all_gather, all_reduce, all_reduce_fusion, reduce_scatter_fusion
+from . import (
+ _logging,
+ _ptx_utils,
+ comm,
+ fused,
+)
-__all__ = ["_logging", "all_gather", "all_reduce", "all_reduce_fusion", "reduce_scatter_fusion"]
+__all__ = [
+ "_logging",
+ "_ptx_utils",
+ "comm",
+ "fused",
+]
diff --git a/kraken/_ptx_utils/__init__.py b/kraken/_ptx_utils/__init__.py
index 6d9493a..bf73a21 100644
--- a/kraken/_ptx_utils/__init__.py
+++ b/kraken/_ptx_utils/__init__.py
@@ -1,15 +1,19 @@
from .gmem_barrier_arrive_wait import arrive_gmem_barrier, wait_gmem_barrier
from .symm_mem_barrier import (
_get_flat_tid as get_flat_tid,
+)
+from .symm_mem_barrier import (
_send_signal as send_signal,
+)
+from .symm_mem_barrier import (
symm_mem_sync as symm_mem_sync,
)
__all__ = [
"arrive_gmem_barrier",
- "symm_mem_sync",
- "wait_gmem_barrier",
"get_flat_tid",
"send_signal",
+ "symm_mem_sync",
+ "wait_gmem_barrier",
]
# Avoid ptx_utils when possible
diff --git a/kraken/_ptx_utils/symm_mem_barrier.py b/kraken/_ptx_utils/symm_mem_barrier.py
index b42844b..a04f327 100644
--- a/kraken/_ptx_utils/symm_mem_barrier.py
+++ b/kraken/_ptx_utils/symm_mem_barrier.py
@@ -101,7 +101,7 @@ def symm_mem_sync(
rank: tl.constexpr,
world_size: tl.constexpr,
hasPreviousMemAccess: tl.constexpr = False,
- hasSubsequenceMemAccess: tl.constexpr = False,
+ hasSubsequentMemAccess: tl.constexpr = False,
):
"""
Synchronizes blocks with matching block_id across participating devices.
@@ -112,17 +112,17 @@ def symm_mem_sync(
Pattern 0: Ensures that all writes to symm_mem buffers from previous
kernels across all devices are visible to the current kernel:
- symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequenceMemAccess=True)
+ symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequentMemAccess=True)
Pattern 1: Ensures that all writes to symm_mem buffers from the current
block are visible to all remote blocks with matching blockIdx:
- symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequenceMemAccess=True)
+ symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=True)
Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe
for writing by subsequent kernels across all devices.
- symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequenceMemAccess=False)
+ symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=False)
CUDA graph friendliness:
@@ -152,7 +152,7 @@ def symm_mem_sync(
if flat_tid < world_size:
_send_signal(send_addrs, "release" if hasPreviousMemAccess else "relaxed")
- _wait_signal(wait_addrs, "acquire" if hasSubsequenceMemAccess else "relaxed")
+ _wait_signal(wait_addrs, "acquire" if hasSubsequentMemAccess else "relaxed")
- if hasSubsequenceMemAccess:
+ if hasSubsequentMemAccess:
tl.debug_barrier()
diff --git a/kraken/all_gather/__init__.py b/kraken/all_gather/__init__.py
deleted file mode 100644
index 605884e..0000000
--- a/kraken/all_gather/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .all_gather_matmul import (
- all_gather_matmul as all_gather_matmul,
-)
-
-__all__ = ["all_gather_matmul"]
diff --git a/kraken/all_gather/copy_engine_all_gather.py b/kraken/all_gather/copy_engine_all_gather.py
deleted file mode 100644
index e89b01a..0000000
--- a/kraken/all_gather/copy_engine_all_gather.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-import torch.distributed as dist
-import torch.distributed._symmetric_memory as symm_mem
-
-
-def copy_engine_all_gather_w_progress(
- output: torch.Tensor,
- inp: torch.Tensor, # Must be symmetric tensor
- progress: torch.Tensor,
- splits_per_rank: int,
- backend_stream: torch.cuda.Stream | None = None,
-) -> torch.cuda.Stream:
- backend_stream = symm_mem._get_backend_stream(priority=-1)
- assert inp.is_contiguous()
-
- symm_mem_hdl = symm_mem.rendezvous(inp, group=dist.group.WORLD)
- assert symm_mem_hdl is not None
-
- rank = symm_mem_hdl.rank
- world_size = symm_mem_hdl.world_size
-
- assert inp.numel() % splits_per_rank == 0
- assert progress.numel() >= world_size * splits_per_rank
-
- output_shape = list(inp.shape)
- output_shape[0] *= world_size
- assert list(output.shape) == output_shape, (list(output.shape), output_shape)
-
- chunks = output.chunk(world_size * splits_per_rank)
-
- symm_mem_hdl.barrier()
- backend_stream.wait_stream(torch.cuda.current_stream())
-
- with torch.cuda.stream(backend_stream):
- for step in range(world_size):
- src_rank = (rank + step + 1) % world_size
- for split_id in range(splits_per_rank):
- src_buf = symm_mem_hdl.get_buffer(
- src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
- )
- chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
- # cuStreamWriteValue32 issues a system level fence before the write
- symm_mem_hdl.stream_write_value32(
- progress,
- offset=src_rank * splits_per_rank + split_id,
- val=1,
- )
- symm_mem_hdl.barrier()
-
- return backend_stream
diff --git a/kraken/all_reduce/__init__.py b/kraken/all_reduce/__init__.py
deleted file mode 100644
index 0c48c58..0000000
--- a/kraken/all_reduce/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .one_shot_all_reduce import (
- one_shot_all_reduce as one_shot_all_reduce,
-)
-
-__all__ = ["one_shot_all_reduce"]
diff --git a/kraken/comm/__init__.py b/kraken/comm/__init__.py
new file mode 100644
index 0000000..58fb6c6
--- /dev/null
+++ b/kraken/comm/__init__.py
@@ -0,0 +1,13 @@
+from .copy_engine_all_gather import (
+ _copy_engine_all_gather_w_progress,
+ all_gather_w_progress,
+)
+from .one_shot_all_reduce import (
+ one_shot_all_reduce as one_shot_all_reduce,
+)
+
+__all__ = [
+ "_copy_engine_all_gather_w_progress",
+ "all_gather_w_progress",
+ "one_shot_all_reduce",
+]
diff --git a/kraken/comm/copy_engine_all_gather.py b/kraken/comm/copy_engine_all_gather.py
new file mode 100644
index 0000000..d817520
--- /dev/null
+++ b/kraken/comm/copy_engine_all_gather.py
@@ -0,0 +1,126 @@
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+
+
+def _copy_engine_all_gather_w_progress(
+ output: torch.Tensor,
+ inp: torch.Tensor, # Must be symmetric tensor
+ progress: torch.Tensor,
+ splits_per_rank: int,
+ backend_stream: torch.cuda.Stream | None = None,
+) -> torch.cuda.Stream:
+    if backend_stream is None:
+        backend_stream = symm_mem._get_backend_stream(priority=-1)
+ assert inp.is_contiguous()
+
+ symm_mem_hdl = symm_mem.rendezvous(inp, group=dist.group.WORLD)
+ assert symm_mem_hdl is not None
+
+ rank = symm_mem_hdl.rank
+ world_size = symm_mem_hdl.world_size
+
+ assert inp.numel() % splits_per_rank == 0
+ assert progress.numel() >= world_size * splits_per_rank
+
+ output_shape = list(inp.shape)
+ output_shape[0] *= world_size
+ assert list(output.shape) == output_shape, (list(output.shape), output_shape)
+
+ # Split the output tensor into chunks for each rank and split.
+ chunks = output.chunk(world_size * splits_per_rank)
+
+ # Synchronize all ranks before starting the copy operations.
+ # This ensures any previous operations on the symmetric memory tensor are completed.
+ symm_mem_hdl.barrier()
+ backend_stream.wait_stream(torch.cuda.current_stream())
+
+ # Perform the all-gather operation on the backend stream.
+ with torch.cuda.stream(backend_stream):
+ # Iterate through source rank and splits of the source rank.
+ for step in range(world_size):
+ src_rank = (rank + step + 1) % world_size
+ for split_id in range(splits_per_rank):
+ src_buf = symm_mem_hdl.get_buffer(
+ src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
+ )
+ # Copy data from the source buffer to the corresponding output chunk using copy engine.
+ chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
+ # Signal the completion of the copy for this chunk in progress tensor.
+ # cuStreamWriteValue32 issues a system level fence before the write
+ symm_mem_hdl.stream_write_value32(
+ progress,
+ offset=src_rank * splits_per_rank + split_id,
+ val=1,
+ )
+
+ # Synchronize all ranks after all copy operations are issued.
+ # This ensures all copy operations are completed before proceeding.
+ symm_mem_hdl.barrier()
+
+ return backend_stream
+
+
+def all_gather_w_progress(
+ a_shared: torch.Tensor,
+ a_out: torch.Tensor | None = None,
+ progress: torch.Tensor | None = None,
+ **kwargs,
+) -> torch.Tensor:
+ """
+ Performs an all-gather operation using the copy engine and tracks progress.
+
+ This function gathers data from all ranks into a single output tensor. It uses
+    the copy engine for the data transfer and a progress tensor to signal the
+    completion of each chunk copy.
+ The operation is performed on a backend CUDA stream.
+
+ Args:
+ a_shared (torch.Tensor): The input tensor, which must be a symmetric tensor.
+ Each rank provides its shard of the data in this tensor.
+ a_out (torch.Tensor, optional): The output tensor to store the gathered data.
+ progress (torch.Tensor, optional): A tensor to track the progress of the copy
+ operations. Its size should be at least `world_size * splits_per_rank`.
+ Initially, all elements should be zero. After a chunk is copied,
+ the corresponding element is set to 1.
+        **kwargs: Additional keyword arguments:
+            splits_per_rank (int, optional): The number of splits (chunks) per rank.
+                Defaults to 1.
+
+ Returns:
+ torch.Tensor: The output tensor containing the gathered data from all ranks.
+ """
+ configs = {
+ "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
+ }
+
+ symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
+
+ a_shape = list(a_shared.shape)
+ a_shape[0] *= symm_mem_hdl.world_size
+
+ configs["RANK"] = symm_mem_hdl.rank
+ configs["WORLD_SIZE"] = symm_mem_hdl.world_size
+
+ configs["COMM_BLOCK_SIZE_M"] = (
+ a_shape[0] // configs["WORLD_SIZE"] // configs["SPLITS_PER_RANK"]
+ )
+
+ if a_out is None:
+ a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)
+
+ if progress is None:
+ progress = torch.zeros(
+ symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
+ dtype=torch.uint32,
+ device=a_shared.device,
+ )
+ else:
+ progress.fill_(0) # Reset progress to 0.
+
+ backend_stream = _copy_engine_all_gather_w_progress(
+ a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
+ )
+
+ torch.cuda.current_stream().wait_stream(backend_stream)
+
+ return a_out
diff --git a/kraken/all_reduce/one_shot_all_reduce.py b/kraken/comm/one_shot_all_reduce.py
similarity index 53%
rename from kraken/all_reduce/one_shot_all_reduce.py
rename to kraken/comm/one_shot_all_reduce.py
index 6afed6f..7080c21 100644
--- a/kraken/all_reduce/one_shot_all_reduce.py
+++ b/kraken/comm/one_shot_all_reduce.py
@@ -1,3 +1,10 @@
+"""
+This module implements a Triton kernel for one-shot all-reduce.
+This kernel performs an all-reduce operation on a Torch symmetric memory tensor distributed across
+multiple devices. According to benchmark results, one-shot all-reduce outperforms NCCL ring all-reduce
+for small message sizes (< ~400 KB on an 8xH100 system with NVSwitch).
+"""
+
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
@@ -17,19 +24,22 @@ def one_shot_all_reduce_kernel(
world_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
+ # Synchronize blocks with matching block_id across all participating devices before starting.
+ # This ensures that all previous memory operations are visible.
ptx_utils.symm_mem_sync(
- signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
+ signal_pad_ptrs, None, rank, world_size, hasSubsequentMemAccess=True
)
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
while block_start < numel:
- # Each thread processes 128 bits.
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < numel
acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)
+
+ # Iteratively load from each rank's buffer and accumulate. `static_range` unrolls the loop at compile time, enabling efficient iteration over `buf_tuple`.
for i in tl.static_range(world_size):
buffer_rank = buf_tuple[i]
x = tl.load(buffer_rank + offsets, mask=mask)
@@ -37,12 +47,29 @@ def one_shot_all_reduce_kernel(
tl.store(output_ptr + offsets, acc, mask=mask)
block_start += tl.num_programs(axis=0) * BLOCK_SIZE
+ # Synchronize all participating devices after the reduction is complete.
+ # Subsequent kernel cannot overwrite the symmetric memory buffer until all devices reach this point.
ptx_utils.symm_mem_sync(
signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
)
def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+ """
+ Perform a one-shot all-reduce operation using symmetric memory.
+
+ output = all_reduce(input)
+
+ Args:
+ tensor (torch.Tensor): The input tensor to be reduced. Must be of dtype torch.bfloat16 and 128-bit aligned.
+ **kwargs: Additional keyword arguments for kernel configuration:
+ max_num_blocks (int, optional): The maximum number of blocks to launch.
+ num_warps (int, optional): The number of warps per block.
+ BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel.
+
+ Returns:
+ torch.Tensor: The output tensor containing the reduced result.
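+
+    Example (a minimal sketch, assuming an initialized process group and a
+    rendezvous'd bfloat16 symmetric memory tensor ``t``, as shown in the README):
+
+        out = one_shot_all_reduce(t)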
+ """
config = {
"max_num_blocks": kwargs.get("max_num_blocks", 24),
"num_warps": kwargs.get("num_warps", 32),
@@ -62,12 +89,16 @@ def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD)
output = torch.empty_like(tensor)
+ # Get the buffer pointers for each rank from the symmetric memory handle, and pass them as a tuple to the triton kernel.
buf_list = [
symm_mem_hdl.get_buffer(i, tuple(tensor.shape), tensor.dtype)
for i in range(symm_mem_hdl.world_size)
]
buf_tuple = tuple(buf_list)
+ # symm_mem_hdl.signal_pad_ptrs_dev: An array of pointers pointing to signal_pads for each rank.
+ # A signal pad is a memory region used for synchronization between devices.
+ # `symm_mem_sync` kernel uses these signal pads to implement a cross-device barrier to ensure memory visibility of symmetric memory tensors.
one_shot_all_reduce_kernel[(num_blocks, 1, 1)](
buf_tuple,
symm_mem_hdl.signal_pad_ptrs_dev,
diff --git a/kraken/all_reduce_fusion/__init__.py b/kraken/fused/__init__.py
similarity index 70%
rename from kraken/all_reduce_fusion/__init__.py
rename to kraken/fused/__init__.py
index 22b8e4d..1dc9068 100644
--- a/kraken/all_reduce_fusion/__init__.py
+++ b/kraken/fused/__init__.py
@@ -1,21 +1,27 @@
-from .rms_norm import rms_norm
+from .all_gather_matmul import all_gather_matmul
from .gemm_one_shot_all_reduce_fused import (
gemm_one_shot_all_reduce as gemm_one_shot_all_reduce_fused,
)
+from .gemm_reduce_scatter_ce_persistent import gemm_reduce_scatter_ce_persistent
+from .gemm_reduce_scatter_fused import gemm_reduce_scatter
from .one_shot_all_reduce_bias import one_shot_all_reduce_bias
from .one_shot_all_reduce_bias_rms_norm import (
one_shot_all_reduce_bias_rms_norm,
)
+from .rms_norm import rms_norm
from .two_shot_all_reduce_bias import two_shot_all_reduce_bias
from .two_shot_all_reduce_bias_rms_norm import (
two_shot_all_reduce_bias_rms_norm,
)
__all__ = [
- "rms_norm",
+ "all_gather_matmul",
"gemm_one_shot_all_reduce_fused",
+ "gemm_reduce_scatter",
+ "gemm_reduce_scatter_ce_persistent",
"one_shot_all_reduce_bias",
"one_shot_all_reduce_bias_rms_norm",
+ "rms_norm",
"two_shot_all_reduce_bias",
"two_shot_all_reduce_bias_rms_norm",
]
diff --git a/kraken/all_gather/all_gather_matmul.py b/kraken/fused/all_gather_matmul.py
similarity index 76%
rename from kraken/all_gather/all_gather_matmul.py
rename to kraken/fused/all_gather_matmul.py
index edfc566..19189bd 100644
--- a/kraken/all_gather/all_gather_matmul.py
+++ b/kraken/fused/all_gather_matmul.py
@@ -6,7 +6,7 @@
import triton.tools.experimental_descriptor
from .._ptx_utils import wait_gmem_barrier
-from .copy_engine_all_gather import copy_engine_all_gather_w_progress
+from ..comm import _copy_engine_all_gather_w_progress
def _matmul_launch_metadata(grid, kernel, args):
@@ -43,7 +43,11 @@ def _matmul_kernel_tma_persistent_w_progress(
NUM_SMS: tl.constexpr,
):
"""
- Slightly modified from the sm90 tma persistent Triton tutorial.
+ Persistent Triton kernel for matrix multiplication with progress waiting.
+
+ This kernel performs matrix multiplication (`C = A @ B`) in a persistent manner.
+ It waits for chunks of the `A` matrix to be gathered from other ranks by
+ monitoring a `progress_ptr` before consuming them.
"""
dtype = tl.float8e4nv if FP8_OUTPUT else tl.bfloat16
@@ -237,6 +241,44 @@ def all_gather_matmul(
progress: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Performs a fused all-gather and matrix multiplication operation.
+
+ This function first performs an all-gather operation on the `a_shared` tensor
+ to construct the full `A` matrix. It then performs a matrix multiplication
+ `C = A @ B`. The all-gather is performed by the copy engine on a separate
+ stream, and the matrix multiplication is performed by a persistent Triton
+ kernel that waits for the data to be gathered.
+
+ Args:
+ a_shared (torch.Tensor): The local shard of the `A` matrix. This must
+ be a symmetric tensor.
+ b (torch.Tensor): The `B` matrix.
+ a_out (torch.Tensor | None, optional): The output tensor for the
+ all-gathered `A` matrix. If None, a new tensor is created.
+ progress (torch.Tensor | None, optional): A tensor for tracking the
+ progress of the all-gather operation. If None, a new tensor is
+ created.
+ **kwargs: Additional keyword arguments for kernel configuration:
+ splits_per_rank (int, optional): The number of splits for the
+ all-gather operation. Defaults to 1.
+ block_size_m (int, optional): The block size for the M dimension.
+ Defaults to 128.
+ block_size_n (int, optional): The block size for the N dimension.
+ Defaults to 256.
+ block_size_k (int, optional): The block size for the K dimension.
+ Defaults to 64.
+ group_size_m (int, optional): The group size for the M dimension.
+ Defaults to 4.
+ num_stages (int, optional): The number of stages for the matmul
+ kernel. Defaults to 3.
+ num_warps (int, optional): The number of warps for the matmul
+ kernel. Defaults to 8.
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: A tuple containing the all-gathered
+ `A` matrix and the result of the matrix multiplication `C`.
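+
+    Example (a minimal sketch, assuming ``a_shared`` is a rendezvous'd symmetric
+    memory shard and ``b`` is a regular matrix on the same device):
+
+        a_gathered, c = all_gather_matmul(a_shared, b)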
+ """
configs = {
"SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
"BLOCK_SIZE_M": kwargs.get("block_size_m", 128),
@@ -285,10 +327,12 @@ def all_gather_matmul(
else:
progress.fill_(0) # Reset progress to 0.
- backend_stream = copy_engine_all_gather_w_progress(
+ # Perform all-gather using the copy engine on a backend stream.
+ backend_stream = _copy_engine_all_gather_w_progress(
a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
)
+    # Perform the matrix multiplication on the gathered A; the matmul kernel waits
+    # for the completion signal of each chunk of A before consuming it.
c = _matmul_w_progress(a_out, a_shared, b, progress, configs)
torch.cuda.current_stream().wait_stream(backend_stream)
diff --git a/kraken/all_reduce_fusion/gemm_one_shot_all_reduce_fused.py b/kraken/fused/gemm_one_shot_all_reduce_fused.py
similarity index 96%
rename from kraken/all_reduce_fusion/gemm_one_shot_all_reduce_fused.py
rename to kraken/fused/gemm_one_shot_all_reduce_fused.py
index 68395c1..56724b4 100644
--- a/kraken/all_reduce_fusion/gemm_one_shot_all_reduce_fused.py
+++ b/kraken/fused/gemm_one_shot_all_reduce_fused.py
@@ -78,7 +78,7 @@ def gemm_one_shot_all_reduce_kernel(
# Synchronize before all-reduce
ptx_utils.symm_mem_sync(
- signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
+ signal_pad_ptrs, None, rank, world_size, hasSubsequentMemAccess=True
)
# All-reduce: sum results from all ranks
@@ -115,9 +115,9 @@ def gemm_one_shot_all_reduce(
Output matrix of shape (M, N) containing the all-reduced result
"""
- assert (
- a.shape[1] == b.shape[0]
- ), "Inner dimensions must match for matrix multiplication"
+ assert a.shape[1] == b.shape[0], (
+ "Inner dimensions must match for matrix multiplication"
+ )
M, K = a.shape
K, N = b.shape
diff --git a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py b/kraken/fused/gemm_reduce_scatter_ce_persistent.py
similarity index 99%
rename from kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py
rename to kraken/fused/gemm_reduce_scatter_ce_persistent.py
index 97d358b..413f529 100644
--- a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py
+++ b/kraken/fused/gemm_reduce_scatter_ce_persistent.py
@@ -1,10 +1,9 @@
+from cuda.bindings import driver
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
-
import triton
import triton.language as tl
-from cuda.bindings import driver
from .._ptx_utils import get_flat_tid, send_signal
@@ -189,7 +188,7 @@ def gemm_producer_w_progress(
):
M, K = a.shape
Kb, N = b.shape
- assert K == Kb, "Inner dimensions must match for matrix multiplication"
+ assert Kb == K, "Inner dimensions must match for matrix multiplication"
assert a.dtype == b.dtype, "Input dtypes must match"
bT = b.T
diff --git a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused.py b/kraken/fused/gemm_reduce_scatter_fused.py
similarity index 95%
rename from kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused.py
rename to kraken/fused/gemm_reduce_scatter_fused.py
index da88f47..fca06c1 100644
--- a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused.py
+++ b/kraken/fused/gemm_reduce_scatter_fused.py
@@ -61,7 +61,7 @@ def gemm_reduce_scatter_kernel(
# GEMM Computation
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ for k in range(tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0)
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), other=0.0)
accumulator = tl.dot(a, b, accumulator)
@@ -91,7 +91,7 @@ def gemm_reduce_scatter_kernel(
# synchronize
ptx_utils.symm_mem_sync(
- signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
+ signal_pad_ptrs, None, rank, world_size, hasSubsequentMemAccess=True
)
# Reduce Scatter logic: For each tile in the rank's assigned row slice (along M),
@@ -161,9 +161,9 @@ def gemm_reduce_scatter(a: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Ten
Output matrix of shape (M / world_size, N) containing the reduce-scattered result.
"""
- assert (
- a.shape[1] == b.shape[0]
- ), "Inner dimensions must match for matrix multiplication"
+ assert a.shape[1] == b.shape[0], (
+ "Inner dimensions must match for matrix multiplication"
+ )
M, K = a.shape
_, N = b.shape
@@ -171,9 +171,9 @@ def gemm_reduce_scatter(a: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Ten
world_size = dist.get_world_size(group)
rank = dist.get_rank(group)
- assert (
- M % world_size == 0
- ), f"M dimension ({M}) must be divisible by world_size ({world_size})"
+ assert M % world_size == 0, (
+ f"M dimension ({M}) must be divisible by world_size ({world_size})"
+ )
# Configuration stuff
BLOCK_SIZE_M = kwargs.get("BLOCK_SIZE_M", 64)
diff --git a/kraken/all_reduce_fusion/one_shot_all_reduce_bias.py b/kraken/fused/one_shot_all_reduce_bias.py
similarity index 99%
rename from kraken/all_reduce_fusion/one_shot_all_reduce_bias.py
rename to kraken/fused/one_shot_all_reduce_bias.py
index 5fcb83c..45f78f4 100644
--- a/kraken/all_reduce_fusion/one_shot_all_reduce_bias.py
+++ b/kraken/fused/one_shot_all_reduce_bias.py
@@ -70,7 +70,7 @@ def one_shot_all_reduce_bias_kernel(
rank,
world_size,
hasPreviousMemAccess=True,
- hasSubsequenceMemAccess=True,
+ hasSubsequentMemAccess=True,
)
block_start = pid * BLOCK_SIZE
diff --git a/kraken/all_reduce_fusion/one_shot_all_reduce_bias_rms_norm.py b/kraken/fused/one_shot_all_reduce_bias_rms_norm.py
similarity index 99%
rename from kraken/all_reduce_fusion/one_shot_all_reduce_bias_rms_norm.py
rename to kraken/fused/one_shot_all_reduce_bias_rms_norm.py
index c16c5e9..9e7c3e7 100644
--- a/kraken/all_reduce_fusion/one_shot_all_reduce_bias_rms_norm.py
+++ b/kraken/fused/one_shot_all_reduce_bias_rms_norm.py
@@ -66,7 +66,7 @@ def one_shot_all_reduce_bias_rms_norm_kernel(
rank,
world_size,
hasPreviousMemAccess=True,
- hasSubsequenceMemAccess=True,
+ hasSubsequentMemAccess=True,
)
# Allreduce + bias
diff --git a/kraken/all_reduce_fusion/rms_norm.py b/kraken/fused/rms_norm.py
similarity index 100%
rename from kraken/all_reduce_fusion/rms_norm.py
rename to kraken/fused/rms_norm.py
diff --git a/kraken/all_reduce_fusion/two_shot_all_reduce_bias.py b/kraken/fused/two_shot_all_reduce_bias.py
similarity index 98%
rename from kraken/all_reduce_fusion/two_shot_all_reduce_bias.py
rename to kraken/fused/two_shot_all_reduce_bias.py
index 05128f6..48ddf2f 100644
--- a/kraken/all_reduce_fusion/two_shot_all_reduce_bias.py
+++ b/kraken/fused/two_shot_all_reduce_bias.py
@@ -60,7 +60,7 @@ def two_shot_all_reduce_bias_kernel(
rank,
world_size,
hasPreviousMemAccess=True,
- hasSubsequenceMemAccess=True,
+ hasSubsequentMemAccess=True,
)
# Two-shot allreduce
@@ -95,7 +95,7 @@ def two_shot_all_reduce_bias_kernel(
rank,
world_size,
hasPreviousMemAccess=True,
- hasSubsequenceMemAccess=True,
+ hasSubsequentMemAccess=True,
)
# Copy the result from the symmetric memory buffer to the output.
diff --git a/kraken/all_reduce_fusion/two_shot_all_reduce_bias_rms_norm.py b/kraken/fused/two_shot_all_reduce_bias_rms_norm.py
similarity index 98%
rename from kraken/all_reduce_fusion/two_shot_all_reduce_bias_rms_norm.py
rename to kraken/fused/two_shot_all_reduce_bias_rms_norm.py
index c7e4226..f60d39e 100644
--- a/kraken/all_reduce_fusion/two_shot_all_reduce_bias_rms_norm.py
+++ b/kraken/fused/two_shot_all_reduce_bias_rms_norm.py
@@ -72,7 +72,7 @@ def two_shot_all_reduce_bias_rms_norm_kernel(
rank,
world_size,
hasPreviousMemAccess=True,
- hasSubsequenceMemAccess=True,
+ hasSubsequentMemAccess=True,
)
# Two shot allreduce
@@ -104,7 +104,7 @@ def two_shot_all_reduce_bias_rms_norm_kernel(
rank,
world_size,
hasPreviousMemAccess=True,
- hasSubsequenceMemAccess=True,
+ hasSubsequentMemAccess=True,
)
# The regular RMSNorm
diff --git a/kraken/reduce_scatter_fusion/__init__.py b/kraken/reduce_scatter_fusion/__init__.py
deleted file mode 100644
index 3ebae8b..0000000
--- a/kraken/reduce_scatter_fusion/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .gemm_reduce_scatter_ce_persistent import gemm_reduce_scatter_ce_persistent
-from .gemm_reduce_scatter_fused import gemm_reduce_scatter
-
-__all__ = ["gemm_reduce_scatter", "gemm_reduce_scatter_ce_persistent"]
diff --git a/requirements.txt b/requirements.txt
index 7d0fafa..321e4fe 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
# Saw a worse performance when using older triton version.
# This will ensure that people can reproduce the same performance.
-triton >= 3.3.0
+triton == 3.3.0
torch >= 2.6.0; python_version >= "3.10"
tyro
cuda-bindings
diff --git a/test/test_allgather.py b/test/test_allgather.py
new file mode 100644
index 0000000..40dc8a7
--- /dev/null
+++ b/test/test_allgather.py
@@ -0,0 +1,78 @@
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+import torch.distributed._functional_collectives as fc
+import torch.distributed._symmetric_memory as symm_mem
+from torch.testing._internal.common_distributed import (
+ MultiProcessTestCase,
+ skip_if_lt_x_gpu,
+)
+from torch.testing._internal.common_utils import (
+ instantiate_parametrized_tests,
+ run_tests,
+)
+
+# Add the parent directory to the Python path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+import kraken
+
+
+@instantiate_parametrized_tests
+class TritonAllGatherTest(MultiProcessTestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ self._spawn_processes()
+
+ @property
+ def world_size(self) -> int:
+        # Use all available GPUs (the test below is skipped if fewer than 4 are present).
+ return torch.cuda.device_count()
+
+ @property
+ def device(self) -> torch.device:
+ return torch.device(f"cuda:{self.rank}")
+
+ def _init_process(self):
+ torch.cuda.set_device(self.device)
+ store = dist.FileStore(self.file_name, self.world_size)
+ dist.init_process_group(
+ backend="nccl",
+ world_size=self.world_size,
+ rank=self.rank,
+ store=store,
+ )
+ torch.manual_seed(42 + self.rank)
+
+ @skip_if_lt_x_gpu(4)
+ def test_all_gather_w_progress(self):
+ self._init_process()
+ group_name = dist.group.WORLD.group_name
+ a_shared = symm_mem.empty(
+ (1024, 1024),
+ dtype=torch.bfloat16,
+ device=self.device,
+ ).normal_()
+ symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
+
+ progress = torch.zeros(
+ symm_mem_hdl.world_size,
+ dtype=torch.uint32,
+ device=self.device,
+ )
+
+ golden_a = a_shared.clone()
+ a_gathered = fc.all_gather_tensor(golden_a, 0, "0")
+
+ a_out = kraken.comm.all_gather_w_progress(a_shared, progress=progress)
+
+ torch.testing.assert_close(a_out, a_gathered)
+ assert torch.all(progress != 0)
+
+ dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/test_allgather_matmul.py b/test/test_allgather_matmul.py
index fb7de0b..a203128 100644
--- a/test/test_allgather_matmul.py
+++ b/test/test_allgather_matmul.py
@@ -64,7 +64,7 @@ def test_all_gather_matmul(self):
).T.contiguous()
b = bT.T
- ag, c = kraken.all_gather.all_gather_matmul(a_shared, b)
+ ag, c = kraken.fused.all_gather_matmul(a_shared, b)
golden_a = a_shared.clone()
ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
diff --git a/test/test_allreduce.py b/test/test_allreduce.py
index 0725f78..b3a30c7 100644
--- a/test/test_allreduce.py
+++ b/test/test_allreduce.py
@@ -57,7 +57,7 @@ def test_one_shot(self):
input_tensor = input_tensor.normal_()
symm_mem.rendezvous(input_tensor, group_name)
- result = kraken.all_reduce.one_shot_all_reduce(input_tensor)
+ result = kraken.comm.one_shot_all_reduce(input_tensor)
golden = input_tensor.clone()
dist.all_reduce(golden)
diff --git a/test/test_allreduce_bias.py b/test/test_allreduce_bias.py
index 4b311f6..0ee7683 100644
--- a/test/test_allreduce_bias.py
+++ b/test/test_allreduce_bias.py
@@ -12,7 +12,7 @@
run_tests,
)
-from kraken.all_reduce_fusion import (
+from kraken.fused import (
one_shot_all_reduce_bias,
two_shot_all_reduce_bias,
)
diff --git a/test/test_allreduce_bias_rms_norm.py b/test/test_allreduce_bias_rms_norm.py
index 139e237..8ab16df 100644
--- a/test/test_allreduce_bias_rms_norm.py
+++ b/test/test_allreduce_bias_rms_norm.py
@@ -12,9 +12,9 @@
run_tests,
)
-from kraken.all_reduce_fusion import (
- rms_norm,
+from kraken.fused import (
one_shot_all_reduce_bias_rms_norm,
+ rms_norm,
two_shot_all_reduce_bias_rms_norm,
)
@@ -74,9 +74,7 @@ def test_one_shot_bias_rms_norm(self):
bias = torch.randn(b, 5120, device=self.device, dtype=torch.bfloat16)
w = torch.randn(5120, device=self.device, dtype=torch.bfloat16)
y = torch.empty_like(input_tensor)
- one_shot_all_reduce_bias_rms_norm(
- symm_mem_buffer, input_tensor, bias, w, y
- )
+ one_shot_all_reduce_bias_rms_norm(symm_mem_buffer, input_tensor, bias, w, y)
baseline = self._nccl_all_reduce_bias_rms_norm(
input_tensor.clone(), w.clone(), bias.clone()
)
@@ -107,9 +105,7 @@ def test_two_shot_bias_rms_norm(self):
bias = torch.randn(b, 5120, device=self.device, dtype=torch.bfloat16)
w = torch.randn(5120, device=self.device, dtype=torch.bfloat16)
y = torch.empty_like(input_tensor)
- two_shot_all_reduce_bias_rms_norm(
- symm_mem_buffer, input_tensor, bias, w, y
- )
+ two_shot_all_reduce_bias_rms_norm(symm_mem_buffer, input_tensor, bias, w, y)
baseline = self._nccl_all_reduce_bias_rms_norm(
input_tensor.clone(), w.clone(), bias.clone()
)
diff --git a/test/test_gemm_allreduce.py b/test/test_gemm_allreduce.py
index 2a05699..1d848e0 100644
--- a/test/test_gemm_allreduce.py
+++ b/test/test_gemm_allreduce.py
@@ -1,10 +1,9 @@
+from datetime import timedelta
import os
import sys
-from datetime import timedelta
import torch
import torch.distributed as dist
-import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
@@ -59,7 +58,7 @@ def test_gemm_all_reduce(self):
b = torch.empty((K, N), dtype=torch.float32, device=self.device).normal_()
# calculate result for our fused kernel
- result = kraken.all_reduce_fusion.gemm_one_shot_all_reduce_fused(a, b)
+ result = kraken.fused.gemm_one_shot_all_reduce_fused(a, b)
# expected value
expected = torch.matmul(a, b)
@@ -79,7 +78,7 @@ def test_gemm_all_reduce_square(self):
b = torch.empty((K, N), dtype=torch.float32, device=self.device).normal_()
# calculate result for our fused kernel
- result = kraken.all_reduce_fusion.gemm_one_shot_all_reduce_fused(a, b)
+ result = kraken.fused.gemm_one_shot_all_reduce_fused(a, b)
# expected value
expected = torch.matmul(a, b)
@@ -98,19 +97,25 @@ def test_rank_specific_values_all_reduce(self):
# Each rank contributes (rank + 1) to the final sum
# This makes it easy to verify all-reduce worked correctly
rank_multiplier = self.rank + 1
- a = torch.ones((M, K), dtype=torch.float32, device=self.device) * rank_multiplier
+ a = (
+ torch.ones((M, K), dtype=torch.float32, device=self.device)
+ * rank_multiplier
+ )
b = torch.ones((K, N), dtype=torch.float32, device=self.device)
- result = kraken.all_reduce_fusion.gemm_one_shot_all_reduce_fused(a, b)
+ result = kraken.fused.gemm_one_shot_all_reduce_fused(a, b)
# Expected: sum of all rank contributions
# rank 0: 1*K, rank 1: 2*K, rank 2: 3*K, rank 3: 4*K
# Total = K * (1+2+3+4) = K * 10
expected_sum = K * sum(range(1, self.world_size + 1)) # K * 10 = K * 10
- expected = torch.full((M, N), expected_sum, dtype=torch.float32, device=self.device)
+ expected = torch.full(
+ (M, N), expected_sum, dtype=torch.float32, device=self.device
+ )
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
dist.destroy_process_group()
+
if __name__ == "__main__":
run_tests()
diff --git a/test/test_gemm_reduce_scatter.py b/test/test_gemm_reduce_scatter.py
index 55ab6ff..ad58c94 100644
--- a/test/test_gemm_reduce_scatter.py
+++ b/test/test_gemm_reduce_scatter.py
@@ -1,10 +1,9 @@
+from datetime import timedelta
import os
import sys
-from datetime import timedelta
import torch
import torch.distributed as dist
-import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
@@ -16,7 +15,7 @@
# Adjust the path to import the kernel from the 'kraken' project directory
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from kraken.reduce_scatter_fusion import (
+from kraken.fused import (
gemm_reduce_scatter,
gemm_reduce_scatter_ce_persistent,
)
@@ -55,10 +54,9 @@ def _get_expected_result(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor
group_name = dist.group.WORLD.group_name
# use torch symm mem's fused_matmul_reduce_scatter impl for testing
- expected_result = torch.ops.symm_mem.fused_matmul_reduce_scatter(
+ return torch.ops.symm_mem.fused_matmul_reduce_scatter(
a, b, "sum", scatter_dim=0, group_name=group_name
)
- return expected_result
@skip_if_lt_x_gpu(4)
def test_gemm_reduce_scatter(self):