diff --git a/README.md b/README.md index 45f4a4f..f431b2e 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,179 @@ -
# Kraken +[**🎯 Features**](#-features) | [**🚀 Getting Started**](#-getting-started) | [**💻 Usage**](#-usage) | [**Benchmarks**](#benchmarks) | [**🤝 Contributing**](#-contributing) | [**⚖️ License**](#️-license) + #### A Triton library of Symmetric Memory operators and examples.
+This repository aims to be a cookbook for developing distributed AI models using Triton and PyTorch's symmetric memory capabilities. + +This is NOT intended to be a "framework" or "library" - it is intended to provide some high-performance Triton implementations with in-kernel communication for developers to hack on :) Please copy-paste and fork as you desire. + -This repository aims to simplify the process of developing distributed AI models using Triton and PyTorch's symmetric memory capabilities. Our initial kernels are adapted from the [Symmetric Memory Recipes](https://github.com/yifuwang/symm-mem-recipes) by Yifu Wang. +In addition, it includes a set of benchmarks to help researchers and developers explore and evaluate their implementations. -## Examples -TBD +Our initial kernels are adapted from the [Symmetric Memory Recipes](https://github.com/yifuwang/symm-mem-recipes) by Yifu Wang. -## Requirements -Kraken requires: -* Triton >= 3.3.0 -* PyTorch >= 2.6.0 -* Python >= 3.10 +## 🎯 Features +- Recipes for high-performance Triton implementations of `all_gather`, `all_reduce`, and `reduce_scatter`. +- Comm-comp fused kernels such as `gemm_one_shot_all_reduce_fused` for increased efficiency. +- A suite of benchmarks to measure and compare the performance of different comm + comp implementations. +- PTX utilities for synchronization primitives not yet supported by Triton. -## Installation +## 🚀 Getting Started +### Prerequisites +- PyTorch (version 2.6.0 or higher) +- Triton (version 3.3.0, as pinned in `requirements.txt`) +- Python (version 3.10 or higher) +- CUDA (version 12.4 or higher; the version must match your PyTorch installation) + +### Installation ```bash git clone https://github.com/meta-pytorch/kraken cd kraken pip install -e . -r requirements.txt ``` -## License -Source code is made available under a [BSD 3 license](./LICENSE), however you may have other legal obligations that govern your use of other content linked in this repository. +## 💻 Usage +Rather than a rigid framework, Kraken is a hands-on tutorial: developers can embed its techniques into xformers, FlashAttention, TorchInductor-generated kernels, or any custom Triton code. + +There are two ways to use Kraken kernels: + + +First, you can import and use the Kraken kernels in your own PyTorch projects. Here is an example of how to use the `one_shot_all_reduce` kernel: + +```python +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import kraken +import os + +# Set up the distributed process group. +local_rank = int(os.environ["LOCAL_RANK"]) +torch.cuda.set_device(f"cuda:{local_rank}") +dist.init_process_group("nccl") + +# Create and initialize a symmetric memory tensor +# See blog: https://dev-discuss.pytorch.org/t/pytorch-symmetricmemory-harnessing-nvlink-programmability-with-ease/279 for symmetric memory details. +a_shared = symm_mem.empty( + (4096, 4096), + dtype=torch.bfloat16, + device=f"cuda:{local_rank}", + ) +symm_mem.rendezvous(a_shared, group=dist.group.WORLD) +a_shared = a_shared.normal_() + +# Call the one_shot_all_reduce kernel from Kraken. +a = kraken.comm.one_shot_all_reduce(a_shared) +``` +Remember to run with torchrun! Example torchrun command: +```shell +torchrun --nnodes 1 --nproc-per-node <num_gpus> \ + --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python \ + python3 example.py +``` + +Alternatively, you can build your own custom kernels by leveraging Kraken's low-level primitives. This allows you to create highly optimized kernels tailored to your specific needs. 
We provide PTX implementations of low-level primitives in `kraken._ptx_utils`. + +Here's an example of how to use `kraken._ptx_utils.symm_mem_sync` to synchronize blocks with matching `block_id` across participating devices in a custom kernel. This is often necessary before and after accessing symmetric memory tensors. + +```python +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem + +import triton +import triton.language as tl + +import kraken +import os + +@triton.jit +def custom_distributed_kernel( + a_shared_ptrs, + a_signal_pad_ptrs, + rank: tl.constexpr, + world_size: tl.constexpr, +): + # Synchronizes blocks with matching block_id across participating devices. + # Ensures that all writes to a_shared from previous kernels across all devices + # are visible to the current kernel: + kraken._ptx_utils.symm_mem_sync( + a_signal_pad_ptrs, + None, + rank, + world_size, + hasPreviousMemAccess=False, + hasSubsequentMemAccess=True, + ) + ... # access a_shared via a_shared_ptrs. + +# Create and initialize a symmetric memory tensor +local_rank = int(os.environ["LOCAL_RANK"]) +torch.cuda.set_device(f"cuda:{local_rank}") +dist.init_process_group("nccl") +a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}") +symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) + +# Define the grid for kernel launch. For simplicity, we use a single thread block. +grid = (1,) + +# Launch the custom kernel. +custom_distributed_kernel[grid]( + symm_mem_hdl.buffer_ptrs_dev, + symm_mem_hdl.signal_pad_ptrs_dev, + rank=symm_mem_hdl.rank, + world_size=symm_mem_hdl.world_size, +) +``` + + +## 📁 Structure +Kraken is organized for easy hacking on distributed Triton kernels: + +### Example Kernels +#### `kraken.comm` +Contains communication kernels with fine-grained synchronization. +- `all_gather_w_progress` +- `one_shot_all_reduce` +- (coming soon) `two_shot_all_reduce` +- (coming soon) `multimem_all_reduce` +#### `kraken.fused` +Fused communication/computation kernels. +- All-gather matmul: `all_gather_matmul` +- GEMM all-reduce: `gemm_one_shot_all_reduce_fused` +- GEMM reduce-scatter: `gemm_reduce_scatter`, `gemm_reduce_scatter_ce_persistent` +- All-reduce + bias: `one_shot_all_reduce_bias`, `two_shot_all_reduce_bias` +- All-reduce + bias + RMSNorm: `one_shot_all_reduce_bias_rms_norm`, `two_shot_all_reduce_bias_rms_norm` + +#### `kraken.quantized` +(coming soon) Fused communication/computation kernels with quantization. + + +### Inline PTX Utils +`kraken._ptx_utils` provides inline PTX implementations of memory-barrier synchronization primitives that are not natively supported by Triton. + + + +### Benchmarks +Kraken includes a set of benchmarks in `benchmark/` to evaluate the performance of its kernels. You can run them as follows: + +```bash +torchrun --nnodes 1 --nproc-per-node <num_gpus> \ +--rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \ +benchmark/benchmark_all_reduce.py +# ... and so on for other benchmarks +``` + +Run with `--help` to see the configurable benchmark arguments (backends, dtype, shapes, etc.) to profile. +```bash +python benchmark/benchmark_all_reduce.py --help +``` + + +## 🤝 Contributing +Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for more details on how to contribute to the project. + +## ⚖️ License +Source code is made available under a [BSD 3 license](./LICENSE); however, you may have other legal obligations that govern your use of other content linked in this repository. 
\ No newline at end of file diff --git a/benchmark/benchmark_all_gather_matmul.py b/benchmark/benchmark_all_gather_matmul.py index ccf70ee..b7fb1bd 100644 --- a/benchmark/benchmark_all_gather_matmul.py +++ b/benchmark/benchmark_all_gather_matmul.py @@ -3,7 +3,6 @@ import csv from dataclasses import asdict, dataclass import functools -import itertools import os import sys @@ -63,15 +62,10 @@ def asdict(self): def generate_experiment_configs( dtype: torch.dtype, - M: list[int], - N: list[int], - K: list[int], + shapes: list[tuple[int, int, int]], backends: list[str], device: torch.device, ) -> list[ExperimentConfig]: - # Generate cross config shapes from M, N, K lists - shapes = list(itertools.product(M, N, K)) - all_configs = [] for shape in shapes: all_configs.append( @@ -93,7 +87,7 @@ def get_single_backend_fn(backend: str): if backend == "torch_symm_mem": return torch_symm_mem_ag_mm if backend == "triton": - return kraken.all_gather.all_gather_matmul + return kraken.fused.all_gather_matmul raise NotImplementedError(backend) @@ -176,9 +170,7 @@ def main(args): torch.manual_seed(42 + local_rank) results = [] - configs = generate_experiment_configs( - args.dtype, args.M, args.N, args.K, args.backend, device - ) + configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device) for config in configs: results.append( Experiment( @@ -196,7 +188,7 @@ def shape_input_type(s): M, N, K = map(int, s.split(",")) return M, N, K except Exception as e: - raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e + raise argparse.ArgumentTypeError("Shape must be M, N, K") from e if __name__ == "__main__": @@ -228,27 +220,15 @@ def shape_input_type(s): ) parser.add_argument( - "-M", - type=shape_input_type, - nargs="+", - default=[2**x for x in range(7, 11)], - help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)", - ) - - parser.add_argument( - "-N", + "--shape", type=shape_input_type, nargs="+", - default=[6656], - help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)", - ) - - parser.add_argument( - "-K", - type=shape_input_type, - nargs="+", - default=[2**x for x in range(12, 15)], - help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)", + default=[ + (m, 6656, k) + for m in [2**x for x in range(7, 11)] + for k in [2**x for x in range(12, 16)] + ], + help="matmul shapes: M, N, K. 
(M, K) @ (K, N) -> (M, N)", ) parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16") diff --git a/benchmark/benchmark_all_reduce.py b/benchmark/benchmark_all_reduce.py index c240263..cd91285 100644 --- a/benchmark/benchmark_all_reduce.py +++ b/benchmark/benchmark_all_reduce.py @@ -114,7 +114,7 @@ def get_single_backend_fn(backend: str): if backend == "dist_2shot": return symm_mem_two_shot_all_reduce if backend == "triton_1shot": - return kraken.all_reduce.one_shot_all_reduce + return kraken.comm.one_shot_all_reduce if backend == "nccl": return nccl_ring raise NotImplementedError(backend) diff --git a/benchmark/benchmark_all_reduce_bias.py b/benchmark/benchmark_all_reduce_bias.py index 88f8a43..fd033f2 100644 --- a/benchmark/benchmark_all_reduce_bias.py +++ b/benchmark/benchmark_all_reduce_bias.py @@ -7,18 +7,15 @@ import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem +import kraken from kraken import _logging as log -from kraken.all_reduce_fusion import ( - one_shot_all_reduce_bias, - two_shot_all_reduce_bias, -) def one_shot_all_reduce_bias( x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor ) -> torch.Tensor: y = torch.empty_like(x) - one_shot_all_reduce_bias(symm_mem_input, x, bias, y) + kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y) return y @@ -26,7 +23,7 @@ def two_shot_all_reduce_bias( x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor ) -> torch.Tensor: y = torch.empty_like(x) - two_shot_all_reduce_bias(symm_mem_input, x, bias, y) + kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y) return y diff --git a/benchmark/benchmark_all_reduce_bias_rms_norm.py b/benchmark/benchmark_all_reduce_bias_rms_norm.py index 93b54c0..c8c5391 100644 --- a/benchmark/benchmark_all_reduce_bias_rms_norm.py +++ b/benchmark/benchmark_all_reduce_bias_rms_norm.py @@ -7,44 +7,42 @@ import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem +import kraken from kraken import _logging as log -from kraken.all_reduce_fusion import ( - rms_norm, - one_shot_all_reduce_bias, - one_shot_all_reduce_bias_rms_norm, - two_shot_all_reduce_bias, - two_shot_all_reduce_bias_rms_norm, -) def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input): y = torch.empty_like(x) - one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y) + kraken.fused.one_shot_all_reduce_bias_rms_norm( + symm_mem_input, x, bias, rms_weight, y + ) return y def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input): y = torch.empty_like(x) - one_shot_all_reduce_bias(symm_mem_input, x, bias, y) - return rms_norm(y, rms_weight) + kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y) + return kraken.fused.rms_norm(y, rms_weight) def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input): y = torch.empty_like(x) - two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y) + kraken.fused.two_shot_all_reduce_bias_rms_norm( + symm_mem_input, x, bias, rms_weight, y + ) return y def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input): y = torch.empty_like(x) - two_shot_all_reduce_bias(symm_mem_input, x, bias, y) - return rms_norm(y, rms_weight) + kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y) + return kraken.fused.rms_norm(y, rms_weight) def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight): dist.all_reduce(x) y = x + bias - return rms_norm(y, rms_weight) + return 
kraken.fused.rms_norm(y, rms_weight) def create_benchmarks(b, t, d_size, device, dtype): diff --git a/benchmark/benchmark_matmul_reduce_scatter.py b/benchmark/benchmark_matmul_reduce_scatter.py index e8179da..119824f 100644 --- a/benchmark/benchmark_matmul_reduce_scatter.py +++ b/benchmark/benchmark_matmul_reduce_scatter.py @@ -1,16 +1,15 @@ import argparse +from collections import defaultdict import csv +from dataclasses import asdict, dataclass import functools -import itertools import os import sys -from collections import defaultdict -from dataclasses import asdict, dataclass +from tabulate import tabulate import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem -from tabulate import tabulate # Add the kraken directory to the Python path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -20,18 +19,22 @@ def torch_symm_mem_gemm_rs(a, b): - gemm_rs_output = torch.ops.symm_mem.fused_matmul_reduce_scatter( + return torch.ops.symm_mem.fused_matmul_reduce_scatter( a, b, "sum", scatter_dim=0, group_name=dist.group.WORLD.group_name ) - return gemm_rs_output + def nccl_mem_gemm_rs(a, b): - from torch.distributed._functional_collectives import reduce_scatter_tensor, wait_tensor + from torch.distributed._functional_collectives import ( + reduce_scatter_tensor, + wait_tensor, + ) gemm_output = torch.matmul(a, b) - rs_o = reduce_scatter_tensor(gemm_output, "sum", scatter_dim=0, group=dist.group.WORLD) - gemm_rs_output = wait_tensor(rs_o) - return gemm_rs_output + rs_o = reduce_scatter_tensor( + gemm_output, "sum", scatter_dim=0, group=dist.group.WORLD + ) + return wait_tensor(rs_o) @dataclass(frozen=True) @@ -64,15 +67,10 @@ def asdict(self): def generate_experiment_configs( dtype: torch.dtype, - M: list[int], - N: list[int], - K: list[int], + shapes: list[tuple[int, int, int]], backends: list[str], device: torch.device, ) -> list[ExperimentConfig]: - # Generate cross config shapes from M, N, K lists - shapes = list(itertools.product(M, N, K)) - all_configs = [] for shape in shapes: all_configs.append( @@ -94,7 +92,7 @@ def get_single_backend_fn(backend: str): if backend == "torch_symm_mem": return torch_symm_mem_gemm_rs if backend == "triton": - return kraken.reduce_scatter_fusion.gemm_reduce_scatter + return kraken.fused.gemm_reduce_scatter raise NotImplementedError(backend) @@ -177,9 +175,7 @@ def main(args): torch.manual_seed(42 + local_rank) results = [] - configs = generate_experiment_configs( - args.dtype, args.M, args.N, args.K, args.backend, device - ) + configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device) for config in configs: results.append( Experiment( @@ -197,7 +193,7 @@ def shape_input_type(s): M, N, K = map(int, s.split(",")) return M, N, K except Exception as e: - raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e + raise argparse.ArgumentTypeError("Shape must be M, N, K") from e if __name__ == "__main__": @@ -229,27 +225,15 @@ def shape_input_type(s): ) parser.add_argument( - "-M", + "--shape", type=shape_input_type, nargs="+", - default=[2**x for x in range(7, 11)], - help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)", - ) - - parser.add_argument( - "-N", - type=shape_input_type, - nargs="+", - default=[6656], - help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)", - ) - - parser.add_argument( - "-K", - type=shape_input_type, - nargs="+", - default=[2**x for x in range(12, 16)], - help="matmul shapes: (M, N, K). 
(M, K) @ (K, N) -> (M, N)", + default=[ + (m, 6656, k) + for m in [2**x for x in range(7, 11)] + for k in [2**x for x in range(12, 16)] + ], + help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)", ) parser.add_argument("-dtype", type=str, help="dtype", default="float32") diff --git a/kraken/__init__.py b/kraken/__init__.py index dffc887..ec4d3e4 100644 --- a/kraken/__init__.py +++ b/kraken/__init__.py @@ -1,3 +1,13 @@ -from . import _logging, all_gather, all_reduce, all_reduce_fusion, reduce_scatter_fusion +from . import ( + _logging, + _ptx_utils, + comm, + fused, +) -__all__ = ["_logging", "all_gather", "all_reduce", "all_reduce_fusion", "reduce_scatter_fusion"] +__all__ = [ + "_logging", + "_ptx_utils", + "comm", + "fused", +] diff --git a/kraken/_ptx_utils/__init__.py b/kraken/_ptx_utils/__init__.py index 6d9493a..bf73a21 100644 --- a/kraken/_ptx_utils/__init__.py +++ b/kraken/_ptx_utils/__init__.py @@ -1,15 +1,19 @@ from .gmem_barrier_arrive_wait import arrive_gmem_barrier, wait_gmem_barrier from .symm_mem_barrier import ( _get_flat_tid as get_flat_tid, +) +from .symm_mem_barrier import ( _send_signal as send_signal, +) +from .symm_mem_barrier import ( symm_mem_sync as symm_mem_sync, ) __all__ = [ "arrive_gmem_barrier", - "symm_mem_sync", - "wait_gmem_barrier", "get_flat_tid", "send_signal", + "symm_mem_sync", + "wait_gmem_barrier", ] # Avoid ptx_utils when possible diff --git a/kraken/_ptx_utils/symm_mem_barrier.py b/kraken/_ptx_utils/symm_mem_barrier.py index b42844b..a04f327 100644 --- a/kraken/_ptx_utils/symm_mem_barrier.py +++ b/kraken/_ptx_utils/symm_mem_barrier.py @@ -101,7 +101,7 @@ def symm_mem_sync( rank: tl.constexpr, world_size: tl.constexpr, hasPreviousMemAccess: tl.constexpr = False, - hasSubsequenceMemAccess: tl.constexpr = False, + hasSubsequentMemAccess: tl.constexpr = False, ): """ Synchronizes blocks with matching block_id across participating devices. @@ -112,17 +112,17 @@ def symm_mem_sync( Pattern 0: Ensures that all writes to symm_mem buffers from previous kernels across all devices are visible to the current kernel: - symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequenceMemAccess=True) + symm_mem_sync(..., hasPreviousMemAccess=False, hasSubsequentMemAccess=True) Pattern 1: Ensures that all writes to symm_mem buffers from the current block are visible to all remote blocks with matching blockIdx: - symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequenceMemAccess=True) + symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=True) Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe for writing by subsequent kernels across all devices. 
- symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequenceMemAccess=False) + symm_mem_sync(..., hasPreviousMemAccess=True, hasSubsequentMemAccess=False) CUDA graph friendliness: @@ -152,7 +152,7 @@ def symm_mem_sync( if flat_tid < world_size: _send_signal(send_addrs, "release" if hasPreviousMemAccess else "relaxed") - _wait_signal(wait_addrs, "acquire" if hasSubsequenceMemAccess else "relaxed") + _wait_signal(wait_addrs, "acquire" if hasSubsequentMemAccess else "relaxed") - if hasSubsequenceMemAccess: + if hasSubsequentMemAccess: tl.debug_barrier() diff --git a/kraken/all_gather/__init__.py b/kraken/all_gather/__init__.py deleted file mode 100644 index 605884e..0000000 --- a/kraken/all_gather/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .all_gather_matmul import ( - all_gather_matmul as all_gather_matmul, -) - -__all__ = ["all_gather_matmul"] diff --git a/kraken/all_gather/copy_engine_all_gather.py b/kraken/all_gather/copy_engine_all_gather.py deleted file mode 100644 index e89b01a..0000000 --- a/kraken/all_gather/copy_engine_all_gather.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem - - -def copy_engine_all_gather_w_progress( - output: torch.Tensor, - inp: torch.Tensor, # Must be symmetric tensor - progress: torch.Tensor, - splits_per_rank: int, - backend_stream: torch.cuda.Stream | None = None, -) -> torch.cuda.Stream: - backend_stream = symm_mem._get_backend_stream(priority=-1) - assert inp.is_contiguous() - - symm_mem_hdl = symm_mem.rendezvous(inp, group=dist.group.WORLD) - assert symm_mem_hdl is not None - - rank = symm_mem_hdl.rank - world_size = symm_mem_hdl.world_size - - assert inp.numel() % splits_per_rank == 0 - assert progress.numel() >= world_size * splits_per_rank - - output_shape = list(inp.shape) - output_shape[0] *= world_size - assert list(output.shape) == output_shape, (list(output.shape), output_shape) - - chunks = output.chunk(world_size * splits_per_rank) - - symm_mem_hdl.barrier() - backend_stream.wait_stream(torch.cuda.current_stream()) - - with torch.cuda.stream(backend_stream): - for step in range(world_size): - src_rank = (rank + step + 1) % world_size - for split_id in range(splits_per_rank): - src_buf = symm_mem_hdl.get_buffer( - src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id - ) - chunks[src_rank * splits_per_rank + split_id].copy_(src_buf) - # cuStreamWriteValue32 issues a system level fence before the write - symm_mem_hdl.stream_write_value32( - progress, - offset=src_rank * splits_per_rank + split_id, - val=1, - ) - symm_mem_hdl.barrier() - - return backend_stream diff --git a/kraken/all_reduce/__init__.py b/kraken/all_reduce/__init__.py deleted file mode 100644 index 0c48c58..0000000 --- a/kraken/all_reduce/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .one_shot_all_reduce import ( - one_shot_all_reduce as one_shot_all_reduce, -) - -__all__ = ["one_shot_all_reduce"] diff --git a/kraken/comm/__init__.py b/kraken/comm/__init__.py new file mode 100644 index 0000000..58fb6c6 --- /dev/null +++ b/kraken/comm/__init__.py @@ -0,0 +1,13 @@ +from .copy_engine_all_gather import ( + _copy_engine_all_gather_w_progress, + all_gather_w_progress, +) +from .one_shot_all_reduce import ( + one_shot_all_reduce as one_shot_all_reduce, +) + +__all__ = [ + "_copy_engine_all_gather_w_progress", + "all_gather_w_progress", + "one_shot_all_reduce", +] diff --git a/kraken/comm/copy_engine_all_gather.py b/kraken/comm/copy_engine_all_gather.py new file mode 100644 
index 0000000..d817520 --- /dev/null +++ b/kraken/comm/copy_engine_all_gather.py @@ -0,0 +1,126 @@ +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem + + +def _copy_engine_all_gather_w_progress( + output: torch.Tensor, + inp: torch.Tensor, # Must be symmetric tensor + progress: torch.Tensor, + splits_per_rank: int, + backend_stream: torch.cuda.Stream | None = None, +) -> torch.cuda.Stream: + backend_stream = symm_mem._get_backend_stream(priority=-1) + assert inp.is_contiguous() + + symm_mem_hdl = symm_mem.rendezvous(inp, group=dist.group.WORLD) + assert symm_mem_hdl is not None + + rank = symm_mem_hdl.rank + world_size = symm_mem_hdl.world_size + + assert inp.numel() % splits_per_rank == 0 + assert progress.numel() >= world_size * splits_per_rank + + output_shape = list(inp.shape) + output_shape[0] *= world_size + assert list(output.shape) == output_shape, (list(output.shape), output_shape) + + # Split the output tensor into chunks for each rank and split. + chunks = output.chunk(world_size * splits_per_rank) + + # Synchronize all ranks before starting the copy operations. + # This ensures any previous operations on the symmetric memory tensor are completed. + symm_mem_hdl.barrier() + backend_stream.wait_stream(torch.cuda.current_stream()) + + # Perform the all-gather operation on the backend stream. + with torch.cuda.stream(backend_stream): + # Iterate through source rank and splits of the source rank. + for step in range(world_size): + src_rank = (rank + step + 1) % world_size + for split_id in range(splits_per_rank): + src_buf = symm_mem_hdl.get_buffer( + src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id + ) + # Copy data from the source buffer to the corresponding output chunk using copy engine. + chunks[src_rank * splits_per_rank + split_id].copy_(src_buf) + # Signal the completion of the copy for this chunk in progress tensor. + # cuStreamWriteValue32 issues a system level fence before the write + symm_mem_hdl.stream_write_value32( + progress, + offset=src_rank * splits_per_rank + split_id, + val=1, + ) + + # Synchronize all ranks after all copy operations are issued. + # This ensures all copy operations are completed before proceeding. + symm_mem_hdl.barrier() + + return backend_stream + + +def all_gather_w_progress( + a_shared: torch.Tensor, + a_out: torch.Tensor | None = None, + progress: torch.Tensor | None = None, + **kwargs, +) -> torch.Tensor: + """ + Performs an all-gather operation using the copy engine and tracks progress. + + This function gathers data from all ranks into a single output tensor. It uses + the copy engine for the data transfer and a progress tensor to signal the + completion of each chunk copy in the progress tensor. + The operation is performed on a backend CUDA stream. + + Args: + a_shared (torch.Tensor): The input tensor, which must be a symmetric tensor. + Each rank provides its shard of the data in this tensor. + a_out (torch.Tensor, optional): The output tensor to store the gathered data. + progress (torch.Tensor, optional): A tensor to track the progress of the copy + operations. Its size should be at least `world_size * splits_per_rank`. + Initially, all elements should be zero. After a chunk is copied, + the corresponding element is set to 1. + splits_per_rank (int): The number of splits (chunks) per rank. + backend_stream (torch.cuda.Stream, optional): A background CUDA stream for + the copy engine operations. If not provided, a new stream is created. 
+ + Returns: + torch.Tensor: The output tensor containing the gathered data from all ranks. + """ + configs = { + "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1), + } + + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) + + a_shape = list(a_shared.shape) + a_shape[0] *= symm_mem_hdl.world_size + + configs["RANK"] = symm_mem_hdl.rank + configs["WORLD_SIZE"] = symm_mem_hdl.world_size + + configs["COMM_BLOCK_SIZE_M"] = ( + a_shape[0] // configs["WORLD_SIZE"] // configs["SPLITS_PER_RANK"] + ) + + if a_out is None: + a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device) + + if progress is None: + progress = torch.zeros( + symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"], + dtype=torch.uint32, + device=a_shared.device, + ) + else: + progress.fill_(0) # Reset progress to 0. + + backend_stream = _copy_engine_all_gather_w_progress( + a_out, a_shared, progress, configs["SPLITS_PER_RANK"] + ) + + torch.cuda.current_stream().wait_stream(backend_stream) + + return a_out diff --git a/kraken/all_reduce/one_shot_all_reduce.py b/kraken/comm/one_shot_all_reduce.py similarity index 53% rename from kraken/all_reduce/one_shot_all_reduce.py rename to kraken/comm/one_shot_all_reduce.py index 6afed6f..7080c21 100644 --- a/kraken/all_reduce/one_shot_all_reduce.py +++ b/kraken/comm/one_shot_all_reduce.py @@ -1,3 +1,10 @@ +""" +This module implements a Triton kernel for one-shot all-reduce. +This kernel performs an all-reduce operation on a Torch symmetric memory tensor distributed across +multiple devices. According to benchmark results, one-shot all reduce outperforms NCCL ring reduce +for small message sizes (<~400KB on a 8xH100 system with NVSwitch). +""" + import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem @@ -17,19 +24,22 @@ def one_shot_all_reduce_kernel( world_size: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): + # Synchronize blocks with matching block_id across all participating devices before starting. + # This ensures that all previous memory operations are visible. ptx_utils.symm_mem_sync( - signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True + signal_pad_ptrs, None, rank, world_size, hasSubsequentMemAccess=True ) pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE while block_start < numel: - # Each thread processes 128 bits. offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < numel acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16) + + # Iteratively load from each rank's buffer and accumulate. `static_range` unrolls the loop at compile time, enabling efficient iteration over `buf_tuple`. for i in tl.static_range(world_size): buffer_rank = buf_tuple[i] x = tl.load(buffer_rank + offsets, mask=mask) @@ -37,12 +47,29 @@ def one_shot_all_reduce_kernel( tl.store(output_ptr + offsets, acc, mask=mask) block_start += tl.num_programs(axis=0) * BLOCK_SIZE + # Synchronize all participating devices after the reduction is complete. + # Subsequent kernel cannot overwrite the symmetric memory buffer until all devices reach this point. ptx_utils.symm_mem_sync( signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True ) def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Perform a one-shot all-reduce operation using symmetric memory. + + output = all_reduce(input) + + Args: + tensor (torch.Tensor): The input tensor to be reduced. Must be of dtype torch.bfloat16 and 128-bit aligned. 
+ **kwargs: Additional keyword arguments for kernel configuration: + max_num_blocks (int, optional): The maximum number of blocks to launch. + num_warps (int, optional): The number of warps per block. + BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel. + + Returns: + torch.Tensor: The output tensor containing the reduced result. + """ config = { "max_num_blocks": kwargs.get("max_num_blocks", 24), "num_warps": kwargs.get("num_warps", 32), @@ -62,12 +89,16 @@ def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor: symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD) output = torch.empty_like(tensor) + # Get the buffer pointers for each rank from the symmetric memory handle, and pass them as a tuple to the triton kernel. buf_list = [ symm_mem_hdl.get_buffer(i, tuple(tensor.shape), tensor.dtype) for i in range(symm_mem_hdl.world_size) ] buf_tuple = tuple(buf_list) + # symm_mem_hdl.signal_pad_ptrs_dev: An array of pointers pointing to signal_pads for each rank. + # A signal pad is a memory region used for synchronization between devices. + # `symm_mem_sync` kernel uses these signal pads to implement a cross-device barrier to ensure memory visibility of symmetric memory tensors. one_shot_all_reduce_kernel[(num_blocks, 1, 1)]( buf_tuple, symm_mem_hdl.signal_pad_ptrs_dev, diff --git a/kraken/all_reduce_fusion/__init__.py b/kraken/fused/__init__.py similarity index 70% rename from kraken/all_reduce_fusion/__init__.py rename to kraken/fused/__init__.py index 22b8e4d..1dc9068 100644 --- a/kraken/all_reduce_fusion/__init__.py +++ b/kraken/fused/__init__.py @@ -1,21 +1,27 @@ -from .rms_norm import rms_norm +from .all_gather_matmul import all_gather_matmul from .gemm_one_shot_all_reduce_fused import ( gemm_one_shot_all_reduce as gemm_one_shot_all_reduce_fused, ) +from .gemm_reduce_scatter_ce_persistent import gemm_reduce_scatter_ce_persistent +from .gemm_reduce_scatter_fused import gemm_reduce_scatter from .one_shot_all_reduce_bias import one_shot_all_reduce_bias from .one_shot_all_reduce_bias_rms_norm import ( one_shot_all_reduce_bias_rms_norm, ) +from .rms_norm import rms_norm from .two_shot_all_reduce_bias import two_shot_all_reduce_bias from .two_shot_all_reduce_bias_rms_norm import ( two_shot_all_reduce_bias_rms_norm, ) __all__ = [ - "rms_norm", + "all_gather_matmul", "gemm_one_shot_all_reduce_fused", + "gemm_reduce_scatter", + "gemm_reduce_scatter_ce_persistent", "one_shot_all_reduce_bias", "one_shot_all_reduce_bias_rms_norm", + "rms_norm", "two_shot_all_reduce_bias", "two_shot_all_reduce_bias_rms_norm", ] diff --git a/kraken/all_gather/all_gather_matmul.py b/kraken/fused/all_gather_matmul.py similarity index 76% rename from kraken/all_gather/all_gather_matmul.py rename to kraken/fused/all_gather_matmul.py index edfc566..19189bd 100644 --- a/kraken/all_gather/all_gather_matmul.py +++ b/kraken/fused/all_gather_matmul.py @@ -6,7 +6,7 @@ import triton.tools.experimental_descriptor from .._ptx_utils import wait_gmem_barrier -from .copy_engine_all_gather import copy_engine_all_gather_w_progress +from ..comm import _copy_engine_all_gather_w_progress def _matmul_launch_metadata(grid, kernel, args): @@ -43,7 +43,11 @@ def _matmul_kernel_tma_persistent_w_progress( NUM_SMS: tl.constexpr, ): """ - Slightly modified from the sm90 tma persistent Triton tutorial. + Persistent Triton kernel for matrix multiplication with progress waiting. + + This kernel performs matrix multiplication (`C = A @ B`) in a persistent manner. 
+ It waits for chunks of the `A` matrix to be gathered from other ranks by + monitoring a `progress_ptr` before consuming them. """ dtype = tl.float8e4nv if FP8_OUTPUT else tl.bfloat16 @@ -237,6 +241,44 @@ def all_gather_matmul( progress: torch.Tensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Performs a fused all-gather and matrix multiplication operation. + + This function first performs an all-gather operation on the `a_shared` tensor + to construct the full `A` matrix. It then performs a matrix multiplication + `C = A @ B`. The all-gather is performed by the copy engine on a separate + stream, and the matrix multiplication is performed by a persistent Triton + kernel that waits for the data to be gathered. + + Args: + a_shared (torch.Tensor): The local shard of the `A` matrix. This must + be a symmetric tensor. + b (torch.Tensor): The `B` matrix. + a_out (torch.Tensor | None, optional): The output tensor for the + all-gathered `A` matrix. If None, a new tensor is created. + progress (torch.Tensor | None, optional): A tensor for tracking the + progress of the all-gather operation. If None, a new tensor is + created. + **kwargs: Additional keyword arguments for kernel configuration: + splits_per_rank (int, optional): The number of splits for the + all-gather operation. Defaults to 1. + block_size_m (int, optional): The block size for the M dimension. + Defaults to 128. + block_size_n (int, optional): The block size for the N dimension. + Defaults to 256. + block_size_k (int, optional): The block size for the K dimension. + Defaults to 64. + group_size_m (int, optional): The group size for the M dimension. + Defaults to 4. + num_stages (int, optional): The number of stages for the matmul + kernel. Defaults to 3. + num_warps (int, optional): The number of warps for the matmul + kernel. Defaults to 8. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing the all-gathered + `A` matrix and the result of the matrix multiplication `C`. + """ configs = { "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1), "BLOCK_SIZE_M": kwargs.get("block_size_m", 128), @@ -285,10 +327,12 @@ def all_gather_matmul( else: progress.fill_(0) # Reset progress to 0. - backend_stream = copy_engine_all_gather_w_progress( + # Perform all-gather using the copy engine on a backend stream. + backend_stream = _copy_engine_all_gather_w_progress( a_out, a_shared, progress, configs["SPLITS_PER_RANK"] ) + # Perform matrix multiplication on gathered a, which waits for signal of completion for each chunk of a. 
c = _matmul_w_progress(a_out, a_shared, b, progress, configs) torch.cuda.current_stream().wait_stream(backend_stream) diff --git a/kraken/all_reduce_fusion/gemm_one_shot_all_reduce_fused.py b/kraken/fused/gemm_one_shot_all_reduce_fused.py similarity index 96% rename from kraken/all_reduce_fusion/gemm_one_shot_all_reduce_fused.py rename to kraken/fused/gemm_one_shot_all_reduce_fused.py index 68395c1..56724b4 100644 --- a/kraken/all_reduce_fusion/gemm_one_shot_all_reduce_fused.py +++ b/kraken/fused/gemm_one_shot_all_reduce_fused.py @@ -78,7 +78,7 @@ def gemm_one_shot_all_reduce_kernel( # Synchronize before all-reduce ptx_utils.symm_mem_sync( - signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True + signal_pad_ptrs, None, rank, world_size, hasSubsequentMemAccess=True ) # All-reduce: sum results from all ranks @@ -115,9 +115,9 @@ def gemm_one_shot_all_reduce( Output matrix of shape (M, N) containing the all-reduced result """ - assert ( - a.shape[1] == b.shape[0] - ), "Inner dimensions must match for matrix multiplication" + assert a.shape[1] == b.shape[0], ( + "Inner dimensions must match for matrix multiplication" + ) M, K = a.shape K, N = b.shape diff --git a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py b/kraken/fused/gemm_reduce_scatter_ce_persistent.py similarity index 99% rename from kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py rename to kraken/fused/gemm_reduce_scatter_ce_persistent.py index 97d358b..413f529 100644 --- a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_ce_persistent.py +++ b/kraken/fused/gemm_reduce_scatter_ce_persistent.py @@ -1,10 +1,9 @@ +from cuda.bindings import driver import torch import torch.distributed as dist import torch.distributed._symmetric_memory as symm_mem - import triton import triton.language as tl -from cuda.bindings import driver from .._ptx_utils import get_flat_tid, send_signal @@ -189,7 +188,7 @@ def gemm_producer_w_progress( ): M, K = a.shape Kb, N = b.shape - assert K == Kb, "Inner dimensions must match for matrix multiplication" + assert Kb == K, "Inner dimensions must match for matrix multiplication" assert a.dtype == b.dtype, "Input dtypes must match" bT = b.T diff --git a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused.py b/kraken/fused/gemm_reduce_scatter_fused.py similarity index 95% rename from kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused.py rename to kraken/fused/gemm_reduce_scatter_fused.py index da88f47..fca06c1 100644 --- a/kraken/reduce_scatter_fusion/gemm_reduce_scatter_fused.py +++ b/kraken/fused/gemm_reduce_scatter_fused.py @@ -61,7 +61,7 @@ def gemm_reduce_scatter_kernel( # GEMM Computation accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + for k in range(tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0) b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), other=0.0) accumulator = tl.dot(a, b, accumulator) @@ -91,7 +91,7 @@ def gemm_reduce_scatter_kernel( # synchronize ptx_utils.symm_mem_sync( - signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True + signal_pad_ptrs, None, rank, world_size, hasSubsequentMemAccess=True ) # Reduce Scatter logic: For each tile in the rank's assigned row slice (along M), @@ -161,9 +161,9 @@ def gemm_reduce_scatter(a: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Ten Output matrix of shape (M / world_size, N) containing the reduce-scattered result. 
""" - assert ( - a.shape[1] == b.shape[0] - ), "Inner dimensions must match for matrix multiplication" + assert a.shape[1] == b.shape[0], ( + "Inner dimensions must match for matrix multiplication" + ) M, K = a.shape _, N = b.shape @@ -171,9 +171,9 @@ def gemm_reduce_scatter(a: torch.Tensor, b: torch.Tensor, **kwargs) -> torch.Ten world_size = dist.get_world_size(group) rank = dist.get_rank(group) - assert ( - M % world_size == 0 - ), f"M dimension ({M}) must be divisible by world_size ({world_size})" + assert M % world_size == 0, ( + f"M dimension ({M}) must be divisible by world_size ({world_size})" + ) # Configuration stuff BLOCK_SIZE_M = kwargs.get("BLOCK_SIZE_M", 64) diff --git a/kraken/all_reduce_fusion/one_shot_all_reduce_bias.py b/kraken/fused/one_shot_all_reduce_bias.py similarity index 99% rename from kraken/all_reduce_fusion/one_shot_all_reduce_bias.py rename to kraken/fused/one_shot_all_reduce_bias.py index 5fcb83c..45f78f4 100644 --- a/kraken/all_reduce_fusion/one_shot_all_reduce_bias.py +++ b/kraken/fused/one_shot_all_reduce_bias.py @@ -70,7 +70,7 @@ def one_shot_all_reduce_bias_kernel( rank, world_size, hasPreviousMemAccess=True, - hasSubsequenceMemAccess=True, + hasSubsequentMemAccess=True, ) block_start = pid * BLOCK_SIZE diff --git a/kraken/all_reduce_fusion/one_shot_all_reduce_bias_rms_norm.py b/kraken/fused/one_shot_all_reduce_bias_rms_norm.py similarity index 99% rename from kraken/all_reduce_fusion/one_shot_all_reduce_bias_rms_norm.py rename to kraken/fused/one_shot_all_reduce_bias_rms_norm.py index c16c5e9..9e7c3e7 100644 --- a/kraken/all_reduce_fusion/one_shot_all_reduce_bias_rms_norm.py +++ b/kraken/fused/one_shot_all_reduce_bias_rms_norm.py @@ -66,7 +66,7 @@ def one_shot_all_reduce_bias_rms_norm_kernel( rank, world_size, hasPreviousMemAccess=True, - hasSubsequenceMemAccess=True, + hasSubsequentMemAccess=True, ) # Allreduce + bias diff --git a/kraken/all_reduce_fusion/rms_norm.py b/kraken/fused/rms_norm.py similarity index 100% rename from kraken/all_reduce_fusion/rms_norm.py rename to kraken/fused/rms_norm.py diff --git a/kraken/all_reduce_fusion/two_shot_all_reduce_bias.py b/kraken/fused/two_shot_all_reduce_bias.py similarity index 98% rename from kraken/all_reduce_fusion/two_shot_all_reduce_bias.py rename to kraken/fused/two_shot_all_reduce_bias.py index 05128f6..48ddf2f 100644 --- a/kraken/all_reduce_fusion/two_shot_all_reduce_bias.py +++ b/kraken/fused/two_shot_all_reduce_bias.py @@ -60,7 +60,7 @@ def two_shot_all_reduce_bias_kernel( rank, world_size, hasPreviousMemAccess=True, - hasSubsequenceMemAccess=True, + hasSubsequentMemAccess=True, ) # Two-shot allreduce @@ -95,7 +95,7 @@ def two_shot_all_reduce_bias_kernel( rank, world_size, hasPreviousMemAccess=True, - hasSubsequenceMemAccess=True, + hasSubsequentMemAccess=True, ) # Copy the result from the symmetric memory buffer to the output. 
diff --git a/kraken/all_reduce_fusion/two_shot_all_reduce_bias_rms_norm.py b/kraken/fused/two_shot_all_reduce_bias_rms_norm.py similarity index 98% rename from kraken/all_reduce_fusion/two_shot_all_reduce_bias_rms_norm.py rename to kraken/fused/two_shot_all_reduce_bias_rms_norm.py index c7e4226..f60d39e 100644 --- a/kraken/all_reduce_fusion/two_shot_all_reduce_bias_rms_norm.py +++ b/kraken/fused/two_shot_all_reduce_bias_rms_norm.py @@ -72,7 +72,7 @@ def two_shot_all_reduce_bias_rms_norm_kernel( rank, world_size, hasPreviousMemAccess=True, - hasSubsequenceMemAccess=True, + hasSubsequentMemAccess=True, ) # Two shot allreduce @@ -104,7 +104,7 @@ def two_shot_all_reduce_bias_rms_norm_kernel( rank, world_size, hasPreviousMemAccess=True, - hasSubsequenceMemAccess=True, + hasSubsequentMemAccess=True, ) # The regular RMSNorm diff --git a/kraken/reduce_scatter_fusion/__init__.py b/kraken/reduce_scatter_fusion/__init__.py deleted file mode 100644 index 3ebae8b..0000000 --- a/kraken/reduce_scatter_fusion/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .gemm_reduce_scatter_ce_persistent import gemm_reduce_scatter_ce_persistent -from .gemm_reduce_scatter_fused import gemm_reduce_scatter - -__all__ = ["gemm_reduce_scatter", "gemm_reduce_scatter_ce_persistent"] diff --git a/requirements.txt b/requirements.txt index 7d0fafa..321e4fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Saw a worse performance when using older triton version. # This will ensure that people can reproduce the same performance. -triton >= 3.3.0 +triton == 3.3.0 torch >= 2.6.0; python_version >= "3.10" tyro cuda-bindings diff --git a/test/test_allgather.py b/test/test_allgather.py new file mode 100644 index 0000000..40dc8a7 --- /dev/null +++ b/test/test_allgather.py @@ -0,0 +1,78 @@ +import os +import sys + +import torch +import torch.distributed as dist +import torch.distributed._functional_collectives as fc +import torch.distributed._symmetric_memory as symm_mem +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, +) + +# Add the parent directory to the Python path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +import kraken + + +@instantiate_parametrized_tests +class TritonAllGatherTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + # world_size > 2 is needed to verify accumulation order + return torch.cuda.device_count() + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + torch.manual_seed(42 + self.rank) + + @skip_if_lt_x_gpu(4) + def test_all_gather_w_progress(self): + self._init_process() + group_name = dist.group.WORLD.group_name + a_shared = symm_mem.empty( + (1024, 1024), + dtype=torch.bfloat16, + device=self.device, + ).normal_() + symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name) + + progress = torch.zeros( + symm_mem_hdl.world_size, + dtype=torch.uint32, + device=self.device, + ) + + golden_a = a_shared.clone() + a_gathered = fc.all_gather_tensor(golden_a, 0, "0") + + a_out = kraken.comm.all_gather_w_progress(a_shared, progress=progress) + 
+ torch.testing.assert_close(a_out, a_gathered) + assert torch.all(progress != 0) + + dist.destroy_process_group() + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_allgather_matmul.py b/test/test_allgather_matmul.py index fb7de0b..a203128 100644 --- a/test/test_allgather_matmul.py +++ b/test/test_allgather_matmul.py @@ -64,7 +64,7 @@ def test_all_gather_matmul(self): ).T.contiguous() b = bT.T - ag, c = kraken.all_gather.all_gather_matmul(a_shared, b) + ag, c = kraken.fused.all_gather_matmul(a_shared, b) golden_a = a_shared.clone() ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( diff --git a/test/test_allreduce.py b/test/test_allreduce.py index 0725f78..b3a30c7 100644 --- a/test/test_allreduce.py +++ b/test/test_allreduce.py @@ -57,7 +57,7 @@ def test_one_shot(self): input_tensor = input_tensor.normal_() symm_mem.rendezvous(input_tensor, group_name) - result = kraken.all_reduce.one_shot_all_reduce(input_tensor) + result = kraken.comm.one_shot_all_reduce(input_tensor) golden = input_tensor.clone() dist.all_reduce(golden) diff --git a/test/test_allreduce_bias.py b/test/test_allreduce_bias.py index 4b311f6..0ee7683 100644 --- a/test/test_allreduce_bias.py +++ b/test/test_allreduce_bias.py @@ -12,7 +12,7 @@ run_tests, ) -from kraken.all_reduce_fusion import ( +from kraken.fused import ( one_shot_all_reduce_bias, two_shot_all_reduce_bias, ) diff --git a/test/test_allreduce_bias_rms_norm.py b/test/test_allreduce_bias_rms_norm.py index 139e237..8ab16df 100644 --- a/test/test_allreduce_bias_rms_norm.py +++ b/test/test_allreduce_bias_rms_norm.py @@ -12,9 +12,9 @@ run_tests, ) -from kraken.all_reduce_fusion import ( - rms_norm, +from kraken.fused import ( one_shot_all_reduce_bias_rms_norm, + rms_norm, two_shot_all_reduce_bias_rms_norm, ) @@ -74,9 +74,7 @@ def test_one_shot_bias_rms_norm(self): bias = torch.randn(b, 5120, device=self.device, dtype=torch.bfloat16) w = torch.randn(5120, device=self.device, dtype=torch.bfloat16) y = torch.empty_like(input_tensor) - one_shot_all_reduce_bias_rms_norm( - symm_mem_buffer, input_tensor, bias, w, y - ) + one_shot_all_reduce_bias_rms_norm(symm_mem_buffer, input_tensor, bias, w, y) baseline = self._nccl_all_reduce_bias_rms_norm( input_tensor.clone(), w.clone(), bias.clone() ) @@ -107,9 +105,7 @@ def test_two_shot_bias_rms_norm(self): bias = torch.randn(b, 5120, device=self.device, dtype=torch.bfloat16) w = torch.randn(5120, device=self.device, dtype=torch.bfloat16) y = torch.empty_like(input_tensor) - two_shot_all_reduce_bias_rms_norm( - symm_mem_buffer, input_tensor, bias, w, y - ) + two_shot_all_reduce_bias_rms_norm(symm_mem_buffer, input_tensor, bias, w, y) baseline = self._nccl_all_reduce_bias_rms_norm( input_tensor.clone(), w.clone(), bias.clone() ) diff --git a/test/test_gemm_allreduce.py b/test/test_gemm_allreduce.py index 2a05699..1d848e0 100644 --- a/test/test_gemm_allreduce.py +++ b/test/test_gemm_allreduce.py @@ -1,10 +1,9 @@ +from datetime import timedelta import os import sys -from datetime import timedelta import torch import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, @@ -59,7 +58,7 @@ def test_gemm_all_reduce(self): b = torch.empty((K, N), dtype=torch.float32, device=self.device).normal_() # calculate result for our fused kernel - result = kraken.all_reduce_fusion.gemm_one_shot_all_reduce_fused(a, b) + result = kraken.fused.gemm_one_shot_all_reduce_fused(a, b) # expected 
value expected = torch.matmul(a, b) @@ -79,7 +78,7 @@ def test_gemm_all_reduce_square(self): b = torch.empty((K, N), dtype=torch.float32, device=self.device).normal_() # calculate result for our fused kernel - result = kraken.all_reduce_fusion.gemm_one_shot_all_reduce_fused(a, b) + result = kraken.fused.gemm_one_shot_all_reduce_fused(a, b) # expected value expected = torch.matmul(a, b) @@ -98,19 +97,25 @@ def test_rank_specific_values_all_reduce(self): # Each rank contributes (rank + 1) to the final sum # This makes it easy to verify all-reduce worked correctly rank_multiplier = self.rank + 1 - a = torch.ones((M, K), dtype=torch.float32, device=self.device) * rank_multiplier + a = ( + torch.ones((M, K), dtype=torch.float32, device=self.device) + * rank_multiplier + ) b = torch.ones((K, N), dtype=torch.float32, device=self.device) - result = kraken.all_reduce_fusion.gemm_one_shot_all_reduce_fused(a, b) + result = kraken.fused.gemm_one_shot_all_reduce_fused(a, b) # Expected: sum of all rank contributions # rank 0: 1*K, rank 1: 2*K, rank 2: 3*K, rank 3: 4*K # Total = K * (1+2+3+4) = K * 10 expected_sum = K * sum(range(1, self.world_size + 1)) # K * 10 = K * 10 - expected = torch.full((M, N), expected_sum, dtype=torch.float32, device=self.device) + expected = torch.full( + (M, N), expected_sum, dtype=torch.float32, device=self.device + ) torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) dist.destroy_process_group() + if __name__ == "__main__": run_tests() diff --git a/test/test_gemm_reduce_scatter.py b/test/test_gemm_reduce_scatter.py index 55ab6ff..ad58c94 100644 --- a/test/test_gemm_reduce_scatter.py +++ b/test/test_gemm_reduce_scatter.py @@ -1,10 +1,9 @@ +from datetime import timedelta import os import sys -from datetime import timedelta import torch import torch.distributed as dist -import torch.distributed._symmetric_memory as symm_mem from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, @@ -16,7 +15,7 @@ # Adjust the path to import the kernel from the 'kraken' project directory sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from kraken.reduce_scatter_fusion import ( +from kraken.fused import ( gemm_reduce_scatter, gemm_reduce_scatter_ce_persistent, ) @@ -55,10 +54,9 @@ def _get_expected_result(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor group_name = dist.group.WORLD.group_name # use torch symm mem's fused_matmul_reduce_scatter impl for testing - expected_result = torch.ops.symm_mem.fused_matmul_reduce_scatter( + return torch.ops.symm_mem.fused_matmul_reduce_scatter( a, b, "sum", scatter_dim=0, group_name=group_name ) - return expected_result @skip_if_lt_x_gpu(4) def test_gemm_reduce_scatter(self):