@@ -23,8 +23,9 @@ Our initial kernels are adapted from the [Symmetric Memory Recipes](https://gith
 ## 🚀 Getting Started
 ### Prerequisites
 - PyTorch (version 2.6.0 or higher)
-- Triton (version 3.3.0 or higher)
+- Triton (version 3.3.0)
 - Python (version 3.10 or higher)
+- CUDA (version 12.4 or higher). The version must match your PyTorch installation.
 
 ### Installation
 ```bash
@@ -48,8 +49,10 @@ import torch.distributed._symmetric_memory as symm_mem
 import kraken
 import os
 
-# local_rank is needed for device placement, and can be received from the environment
+# Set up the distributed process group.
 local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")
 
 # Create and initialize a symmetric memory tensor
 # See blog: https://dev-discuss.pytorch.org/t/pytorch-symmetricmemory-harnessing-nvlink-programmability-with-ease/279 for symmetric memory details.
@@ -62,7 +65,13 @@ symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
 a_shared = a_shared.normal_()
 
 # Call one_shot_all_reduce kernel from kraken.
-a = kraken.one_shot_all_reduce(a_shared)
+a = kraken.comm.one_shot_all_reduce(a_shared)
+```
+Remember to run with `torchrun`! Example command:
+```shell
+torchrun --nnodes 1 --nproc-per-node <world_size> \
+--rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python \
+python3 example.py
 ```
 
 Alternatively, you can build your own custom kernels by leveraging Kraken's low-level primitives. This allows you to create highly optimized kernels tailored to your specific needs. We provide PTX implementations of low-level primitives in `kraken._ptx_utils`.
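Pieced together, the quick-start fragments in the hunks above form the `example.py` referenced by the `torchrun` command. The sketch below assembles them end to end; the comparison against `dist.all_reduce` and the final `destroy_process_group()` call are illustrative additions, not part of the diff.

```python
# example.py -- a sketch assembled from the quick-start snippets above.
# Launch with the torchrun command shown earlier.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import kraken

# Set up the distributed process group; LOCAL_RANK is provided by torchrun.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group("nccl")

# Create a symmetric memory tensor, rendezvous so every rank can access it, then fill it.
a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}")
symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
a_shared = a_shared.normal_()

# Keep a plain copy for the sanity check below (an addition, not in the original example).
ref = a_shared.clone()

# One-shot all-reduce implemented as a Triton kernel.
a = kraken.comm.one_shot_all_reduce(a_shared)

# Sanity check: the Triton kernel should closely agree with NCCL's all_reduce.
dist.all_reduce(ref)
print(f"rank {local_rank}: max |triton - nccl| = {(a - ref).abs().max().item():.3e}")

dist.destroy_process_group()
```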
@@ -102,6 +111,8 @@ def custom_distributed_kernel(
 
 # Create and initialize a symmetric memory tensor
 local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")
 a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}")
 symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
 
@@ -122,19 +133,22 @@ custom_distributed_kernel[grid](
 Kraken is organized for easy hacking of distributed Triton kernels:
 
 ### Example Kernels
-#### `kraken.all_gather_fusion`
-- `all_gather_matmul`
-#### `kraken.all_reduce_fusion`
-- `rms_norm`,
-- `gemm_one_shot_all_reduce_fused`
-- `one_shot_all_reduce_bias`
-- `one_shot_all_reduce_bias_rms_norm`
-- `two_shot_all_reduce_bias`
-- `two_shot_all_reduce_bias_rms_norm`
+#### `kraken.comm`
+Contains communication kernels with fine-grained synchronization.
+- `all_gather_w_progress`
 - `one_shot_all_reduce`
-#### `kraken.reduce_scatter_fusion`
-- `gemm_reduce_scatter`
-- `gemm_reduce_scatter_ce_persistent`
+- (coming soon) `two_shot_all_reduce`
+- (coming soon) `multimem_all_reduce`
+#### `kraken.fused`
+Fused communication/computation kernels.
+- All-gather + matmul: `all_gather_matmul`
+- GEMM + one-shot all-reduce: `gemm_one_shot_all_reduce_fused`
+- GEMM + reduce-scatter: `gemm_reduce_scatter`, `gemm_reduce_scatter_ce_persistent`
+- All-reduce + bias: `one_shot_all_reduce_bias`, `two_shot_all_reduce_bias`
+- All-reduce + bias + RMSNorm: `one_shot_all_reduce_bias_rms_norm`, `two_shot_all_reduce_bias_rms_norm`
+
+#### `kraken.quantized`
+(coming soon) Fused communication/computation kernels with quantization.
 
 
 ### Inline PTX Utils
@@ -146,10 +160,9 @@ Kraken is organized for easy hacking of distributed Triton kernels:
 Kraken includes a set of benchmarks in `benchmarks/` to evaluate the performance of its kernels. You can run them as follows:
 
 ```bash
-torchrun --nnodes 1 --nproc-per-node 8 \
+torchrun --nnodes 1 --nproc-per-node <world_size> \
 --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
-benchmark/benchmark_all_reduce.py \
---backend nccl,triton_1shot,dist_1shot
+benchmark/benchmark_all_reduce.py
 # ... and so on for other benchmarks
 ```
 
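Concretely, on a single node with 8 GPUs (the value the previous version of the command hard-coded), the invocation would be:

```bash
# Single node, 8 GPUs; set --nproc-per-node to the number of GPUs available.
torchrun --nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
benchmark/benchmark_all_reduce.py
```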