Skip to content

Commit a861a2d

Restructure kernels: divided into comm kernels & fused kernels, (coming soon) quantized kernels
1 parent d631315 commit a861a2d

32 files changed (+399, -189 lines)
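The user-facing effect of this restructuring is a rename of the kernel namespaces. A rough migration map, assembled from the call-site changes in the diffs below (the `OLD_TO_NEW` name is only illustrative, and the list is not exhaustive):

```python
# Illustrative mapping of old dotted paths to the new ones, taken from the diffs in this commit.
OLD_TO_NEW = {
    "kraken.all_reduce_fusion.one_shot_all_reduce": "kraken.comm.one_shot_all_reduce",
    "kraken.all_gather_fusion.all_gather_matmul": "kraken.fused.all_gather_matmul",
    "kraken.reduce_scatter_fusion.gemm_reduce_scatter": "kraken.fused.gemm_reduce_scatter",
    "kraken.all_reduce_fusion.one_shot_all_reduce_bias": "kraken.fused.one_shot_all_reduce_bias",
    "kraken.all_reduce_fusion.two_shot_all_reduce_bias": "kraken.fused.two_shot_all_reduce_bias",
    "kraken.all_reduce_fusion.rms_norm": "kraken.fused.rms_norm",
}
```

Argument lists are unchanged; only the module paths move.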

README.md

Lines changed: 31 additions & 18 deletions
@@ -23,8 +23,9 @@ Our initial kernels are adapted from the [Symmetric Memory Recipes](https://gith
 ## 🚀 Getting Started
 ### Prerequisites
 - PyTorch (version 2.6.0 or higher)
-- Triton (version 3.3.0 or higher)
+- Triton (version 3.3.0)
 - Python (version 3.10 or higher)
+- CUDA (version 12.4 or higher). The version must match your PyTorch installation.

 ### Installation
 ```bash
@@ -48,8 +49,10 @@ import torch.distributed._symmetric_memory as symm_mem
 import kraken
 import os

-# local_rank is needed for device placement, and can be received from the environment
+# Set up the distributed process group.
 local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")

 # Create and initialize a symmetric memory tensor
 # See blog: https://dev-discuss.pytorch.org/t/pytorch-symmetricmemory-harnessing-nvlink-programmability-with-ease/279 for symmetric memory details.
@@ -62,7 +65,13 @@ symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
 a_shared = a_shared.normal_()

 # Call one_shot_all_reduce kernel from kraken.
-a = kraken.one_shot_all_reduce(a_shared)
+a = kraken.comm.one_shot_all_reduce(a_shared)
+```
+Remember to run with torchrun! Example torchrun command:
+```shell
+torchrun --nnodes 1 --nproc-per-node <world_size> \
+    --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python \
+    python3 example.py
 ```

 Alternatively, you can build your own custom kernels by leveraging Kraken's low-level primitives. This allows you to create highly optimized kernels tailored to your specific needs. We provide PTX implementations of low-level primitives in `kraken._ptx_utils`.
@@ -102,6 +111,8 @@ def custom_distributed_kernel(

 # Create and initialize a symmetric memory tensor
 local_rank = int(os.environ["LOCAL_RANK"])
+torch.cuda.set_device(f"cuda:{local_rank}")
+dist.init_process_group("nccl")
 a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}")
 symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

@@ -122,19 +133,22 @@ custom_distributed_kernel[grid](
 Kraken is organized for easy hacking of distributed Triton kernel:

 ### Example Kernels
-#### `kraken.all_gather_fusion`
-- `all_gather_matmul`
-#### `kraken.all_reduce_fusion`
-- `rms_norm`,
-- `gemm_one_shot_all_reduce_fused`
-- `one_shot_all_reduce_bias`
-- `one_shot_all_reduce_bias_rms_norm`
-- `two_shot_all_reduce_bias`
-- `two_shot_all_reduce_bias_rms_norm`
+#### `kraken.comm`
+Contains communication kernels with fine-grained synchronizations.
+- `all_gather_w_progress`
 - `one_shot_all_reduce`
-#### `kraken.reduce_scatter_fusion`
-- `gemm_reduce_scatter`
-- `gemm_reduce_scatter_ce_persistent`
+- (coming soon) `two_shot_all_reduce`
+- (coming soon) `multimem_all_reduce`
+#### `kraken.fused`
+Fused communication/computation kernels.
+- All gather matmul: `all_gather_matmul`
+- Gemm all reduce: `gemm_one_shot_all_reduce_fused`
+- Gemm reduce scatter: `gemm_reduce_scatter`, `gemm_reduce_scatter_ce_persistent`
+- Reduce bias: `one_shot_all_reduce_bias`, `two_shot_all_reduce_bias`
+- Reduce bias rms_norm: `one_shot_all_reduce_bias_rms_norm`, `two_shot_all_reduce_bias_rms_norm`
+
+#### `kraken.quantized`
+(coming soon) Fused communication/computation kernels with quantization.


 ### Inline PTX Utils
@@ -146,10 +160,9 @@ Kraken is organized for easy hacking of distributed Triton kernel:
 Kraken includes a set of benchmarks in `benchmarks/` to evaluate the performance of its kernels. You can run them as follows:

 ```bash
-torchrun --nnodes 1 --nproc-per-node 8 \
+torchrun --nnodes 1 --nproc-per-node <world_size> \
     --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
-    benchmark/benchmark_all_reduce.py \
-    --backend nccl,triton_1shot,dist_1shot
+    benchmark/benchmark_all_reduce.py
 # ... and so on for other benchmarks
 ```
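Pieced together, the updated Getting Started snippet corresponds to a small script roughly like the sketch below. The `torch`/`torch.distributed` imports and the tensor shape/dtype are assumed from the unchanged parts of the README (the shape is taken from the custom-kernel snippet); the file name `example.py` matches the torchrun command above.

```python
# example.py -- sketch assembled from the README snippet above.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import kraken

# Set up the distributed process group; torchrun provides LOCAL_RANK.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group("nccl")

# Create and initialize a symmetric memory tensor.
a_shared = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=f"cuda:{local_rank}")
symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
a_shared = a_shared.normal_()

# Call the one_shot_all_reduce kernel from kraken.
a = kraken.comm.one_shot_all_reduce(a_shared)

dist.destroy_process_group()
```

Launch it with the torchrun command shown in the diff above.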

benchmark/benchmark_all_gather_matmul.py

Lines changed: 11 additions & 31 deletions
@@ -3,7 +3,6 @@
 import csv
 from dataclasses import asdict, dataclass
 import functools
-import itertools
 import os
 import sys

@@ -63,15 +62,10 @@ def asdict(self):

 def generate_experiment_configs(
     dtype: torch.dtype,
-    M: list[int],
-    N: list[int],
-    K: list[int],
+    shapes: list[tuple[int, int, int]],
     backends: list[str],
     device: torch.device,
 ) -> list[ExperimentConfig]:
-    # Generate cross config shapes from M, N, K lists
-    shapes = list(itertools.product(M, N, K))
-
     all_configs = []
     for shape in shapes:
         all_configs.append(
@@ -93,7 +87,7 @@ def get_single_backend_fn(backend: str):
     if backend == "torch_symm_mem":
         return torch_symm_mem_ag_mm
     if backend == "triton":
-        return kraken.all_gather_fusion.all_gather_matmul
+        return kraken.fused.all_gather_matmul
     raise NotImplementedError(backend)


@@ -176,9 +170,7 @@ def main(args):
     torch.manual_seed(42 + local_rank)

     results = []
-    configs = generate_experiment_configs(
-        args.dtype, args.M, args.N, args.K, args.backend, device
-    )
+    configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
     for config in configs:
         results.append(
             Experiment(
@@ -196,7 +188,7 @@ def shape_input_type(s):
         M, N, K = map(int, s.split(","))
         return M, N, K
     except Exception as e:
-        raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
+        raise argparse.ArgumentTypeError("Shape must be M, N, K") from e


 if __name__ == "__main__":
@@ -228,27 +220,15 @@ def shape_input_type(s):
     )

     parser.add_argument(
-        "-M",
-        type=shape_input_type,
-        nargs="+",
-        default=[2**x for x in range(7, 11)],
-        help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
-    )
-
-    parser.add_argument(
-        "-N",
+        "--shape",
         type=shape_input_type,
         nargs="+",
-        default=[6656],
-        help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
-    )
-
-    parser.add_argument(
-        "-K",
-        type=shape_input_type,
-        nargs="+",
-        default=[2**x for x in range(12, 15)],
-        help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
+        default=[
+            (m, 6656, k)
+            for m in [2**x for x in range(7, 11)]
+            for k in [2**x for x in range(12, 16)]
+        ],
+        help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
     )

     parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")
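With the three per-dimension flags collapsed into a single `--shape` argument, an invocation of this benchmark might look like the sketch below. The shape values and world size are illustrative; the torchrun flags mirror the README example.

```bash
torchrun --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
    benchmark/benchmark_all_gather_matmul.py \
    --shape 128,6656,4096 256,6656,8192 \
    -dtype bfloat16
```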

benchmark/benchmark_all_reduce.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def get_single_backend_fn(backend: str):
     if backend == "dist_2shot":
         return symm_mem_two_shot_all_reduce
     if backend == "triton_1shot":
-        return kraken.all_reduce_fusion.one_shot_all_reduce
+        return kraken.comm.one_shot_all_reduce
     if backend == "nccl":
         return nccl_ring
     raise NotImplementedError(backend)

benchmark/benchmark_all_reduce_bias.py

Lines changed: 2 additions & 2 deletions
@@ -15,15 +15,15 @@ def one_shot_all_reduce_bias(
     x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
 ) -> torch.Tensor:
     y = torch.empty_like(x)
-    kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
     return y


 def two_shot_all_reduce_bias(
     x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
 ) -> torch.Tensor:
     y = torch.empty_like(x)
-    kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
     return y

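These wrappers take both a regular input `x` and a symmetric-memory buffer `symm_mem_input`. A minimal driver sketch, assuming a README-style setup; the buffer shape, dtype, and bias shape are illustrative assumptions, as is whether the staging buffer must match `x` in size.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import kraken

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(f"cuda:{local_rank}")
dist.init_process_group("nccl")
device = f"cuda:{local_rank}"

# Symmetric-memory staging buffer shared across ranks (illustrative shape/dtype).
symm_mem_input = symm_mem.empty((4096, 4096), dtype=torch.bfloat16, device=device)
symm_mem.rendezvous(symm_mem_input, group=dist.group.WORLD)

x = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device)  # per-rank partial result
bias = torch.randn(4096, dtype=torch.bfloat16, device=device)     # assumed per-column bias

y = torch.empty_like(x)
# Fused all-reduce + bias, using the same call signature as the wrappers above.
kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)

dist.destroy_process_group()
```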

benchmark/benchmark_all_reduce_bias_rms_norm.py

Lines changed: 11 additions & 7 deletions
@@ -13,32 +13,36 @@

 def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    kraken.all_reduce_fusion.one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+    kraken.fused.one_shot_all_reduce_bias_rms_norm(
+        symm_mem_input, x, bias, rms_weight, y
+    )
     return y


 def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
-    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)
+    kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    return kraken.fused.rms_norm(y, rms_weight)


 def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    kraken.all_reduce_fusion.two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+    kraken.fused.two_shot_all_reduce_bias_rms_norm(
+        symm_mem_input, x, bias, rms_weight, y
+    )
     return y


 def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
-    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)
+    kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    return kraken.fused.rms_norm(y, rms_weight)


 def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight):
     dist.all_reduce(x)
     y = x + bias
-    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)
+    return kraken.fused.rms_norm(y, rms_weight)


 def create_benchmarks(b, t, d_size, device, dtype):

benchmark/benchmark_matmul_reduce_scatter.py

Lines changed: 11 additions & 31 deletions
@@ -3,7 +3,6 @@
 import csv
 from dataclasses import asdict, dataclass
 import functools
-import itertools
 import os
 import sys

@@ -68,15 +67,10 @@ def asdict(self):

 def generate_experiment_configs(
     dtype: torch.dtype,
-    M: list[int],
-    N: list[int],
-    K: list[int],
+    shapes: list[tuple[int, int, int]],
     backends: list[str],
     device: torch.device,
 ) -> list[ExperimentConfig]:
-    # Generate cross config shapes from M, N, K lists
-    shapes = list(itertools.product(M, N, K))
-
     all_configs = []
     for shape in shapes:
         all_configs.append(
9892
if backend == "torch_symm_mem":
9993
return torch_symm_mem_gemm_rs
10094
if backend == "triton":
101-
return kraken.reduce_scatter_fusion.gemm_reduce_scatter
95+
return kraken.fused.gemm_reduce_scatter
10296
raise NotImplementedError(backend)
10397

10498

@@ -181,9 +175,7 @@ def main(args):
181175
torch.manual_seed(42 + local_rank)
182176

183177
results = []
184-
configs = generate_experiment_configs(
185-
args.dtype, args.M, args.N, args.K, args.backend, device
186-
)
178+
configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
187179
for config in configs:
188180
results.append(
189181
Experiment(
@@ -201,7 +193,7 @@ def shape_input_type(s):
201193
M, N, K = map(int, s.split(","))
202194
return M, N, K
203195
except Exception as e:
204-
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
196+
raise argparse.ArgumentTypeError("Shape must be M, N, K") from e
205197

206198

207199
if __name__ == "__main__":
@@ -233,27 +225,15 @@ def shape_input_type(s):
233225
)
234226

235227
parser.add_argument(
236-
"-M",
237-
type=shape_input_type,
238-
nargs="+",
239-
default=[2**x for x in range(7, 11)],
240-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
241-
)
242-
243-
parser.add_argument(
244-
"-N",
228+
"--shape",
245229
type=shape_input_type,
246230
nargs="+",
247-
default=[6656],
248-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
249-
)
250-
251-
parser.add_argument(
252-
"-K",
253-
type=shape_input_type,
254-
nargs="+",
255-
default=[2**x for x in range(12, 16)],
256-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
231+
default=[
232+
(m, 6656, k)
233+
for m in [2**x for x in range(7, 11)]
234+
for k in [2**x for x in range(12, 16)]
235+
],
236+
help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
257237
)
258238

259239
parser.add_argument("-dtype", type=str, help="dtype", default="float32")

kraken/__init__.py

Lines changed: 6 additions & 8 deletions
@@ -1,15 +1,13 @@
 from . import (
     _logging,
-    all_gather_fusion,
-    all_reduce,
-    all_reduce_fusion,
-    reduce_scatter_fusion,
+    _ptx_utils,
+    comm,
+    fused,
 )

 __all__ = [
     "_logging",
-    "all_gather_fusion",
-    "all_reduce",
-    "all_reduce_fusion",
-    "reduce_scatter_fusion",
+    "_ptx_utils",
+    "comm",
+    "fused",
 ]
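A quick sanity check of the new top-level layout after this change (a sketch; it only assumes the names exported above):

```python
import kraken

# New namespaces exported by kraken/__init__.py after this commit.
print(kraken.__all__)  # expected: ['_logging', '_ptx_utils', 'comm', 'fused']

# Old call sites such as kraken.all_reduce_fusion.one_shot_all_reduce move to
# kraken.comm / kraken.fused (see the benchmark updates in this commit).
reduce_fn = kraken.comm.one_shot_all_reduce
fused_fn = kraken.fused.gemm_reduce_scatter
```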
