Skip to content

Commit 8fd9355

Browse files
committed
Restruct kernels: divided into: comm kernels & fused kernels, (comming soon) quantized kernels
1 parent d631315 commit 8fd9355

31 files changed

+372
-172
lines changed

README.md

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ Our initial kernels are adapted from the [Symmetric Memory Recipes](https://gith
2323
## 🚀 Getting Started
2424
### Prerequisites
2525
- PyTorch (version 2.6.0 or higher)
26-
- Triton (version 3.3.0 or higher)
26+
- Triton (version 3.3.0)
2727
- Python (version 3.10 or higher)
28+
- CUDA (version 12.4 or higher) Version must matche your PyTorch installaltion.
2829

2930
### Installation
3031
```bash
@@ -122,19 +123,22 @@ custom_distributed_kernel[grid](
122123
Kraken is organized for easy hacking of distributed Triton kernel:
123124

124125
### Example Kernels
125-
#### `kraken.all_gather_fusion`
126-
- `all_gather_matmul`
127-
#### `kraken.all_reduce_fusion`
128-
- `rms_norm`,
129-
- `gemm_one_shot_all_reduce_fused`
130-
- `one_shot_all_reduce_bias`
131-
- `one_shot_all_reduce_bias_rms_norm`
132-
- `two_shot_all_reduce_bias`
133-
- `two_shot_all_reduce_bias_rms_norm`
126+
#### `kraken.comm`
127+
contains communication kernels with fine-grained sychronizations.
128+
- `all_gather_w_progress`
134129
- `one_shot_all_reduce`
135-
#### `kraken.reduce_scatter_fusion`
136-
- `gemm_reduce_scatter`
137-
- `gemm_reduce_scatter_ce_persistent`
130+
- (coming soon) `two_shot_all_reduce`
131+
- (coming soon) `multimem_all_reduce`
132+
#### `kraken.fused`
133+
Fused communication/computation kernels.
134+
- All gather matmul: `all_gather_matmul`
135+
- Gemm all reduce: `gemm_one_shot_all_reduce_fused`
136+
- Gemm reduce scatter: `gemm_reduce_scatter`, `gemm_reduce_scatter_ce_persistent`
137+
- Reduce bias: `one_shot_all_reduce_bias`, `two_shot_all_reduce_bias`
138+
- Reduce bias rms_norm: `one_shot_all_reduce_bias_rms_norm`, `two_shot_all_reduce_bias_rms_norm`
139+
140+
#### `kraken.quantized`
141+
(comming soon) Fused communication/computation kernels with quantization.
138142

139143

140144
### Inline PTX Utils
@@ -146,10 +150,9 @@ Kraken is organized for easy hacking of distributed Triton kernel:
146150
Kraken includes a set of benchmarks in `benchmarks/` to evaluate the performance of its kernels. You can run them as follows:
147151

148152
```bash
149-
torchrun --nnodes 1 --nproc-per-node 8 \
153+
torchrun --nnodes 1 --nproc-per-node <world_size> \
150154
--rdzv-backend c10d --rdzv-endpoint localhost:0 --no_python python3 \
151-
benchmark/benchmark_all_reduce.py \
152-
--backend nccl,triton_1shot,dist_1shot
155+
benchmark/benchmark_all_reduce.py
153156
# ... and so on for other benchmarks
154157
```
155158

benchmark/benchmark_all_gather_matmul.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import csv
44
from dataclasses import asdict, dataclass
55
import functools
6-
import itertools
76
import os
87
import sys
98

@@ -63,15 +62,10 @@ def asdict(self):
6362

6463
def generate_experiment_configs(
6564
dtype: torch.dtype,
66-
M: list[int],
67-
N: list[int],
68-
K: list[int],
65+
shapes: list[tuple[int, int, int]],
6966
backends: list[str],
7067
device: torch.device,
7168
) -> list[ExperimentConfig]:
72-
# Generate cross config shapes from M, N, K lists
73-
shapes = list(itertools.product(M, N, K))
74-
7569
all_configs = []
7670
for shape in shapes:
7771
all_configs.append(
@@ -93,7 +87,7 @@ def get_single_backend_fn(backend: str):
9387
if backend == "torch_symm_mem":
9488
return torch_symm_mem_ag_mm
9589
if backend == "triton":
96-
return kraken.all_gather_fusion.all_gather_matmul
90+
return kraken.fused.all_gather_matmul
9791
raise NotImplementedError(backend)
9892

9993

@@ -176,9 +170,7 @@ def main(args):
176170
torch.manual_seed(42 + local_rank)
177171

178172
results = []
179-
configs = generate_experiment_configs(
180-
args.dtype, args.M, args.N, args.K, args.backend, device
181-
)
173+
configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
182174
for config in configs:
183175
results.append(
184176
Experiment(
@@ -196,7 +188,7 @@ def shape_input_type(s):
196188
M, N, K = map(int, s.split(","))
197189
return M, N, K
198190
except Exception as e:
199-
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
191+
raise argparse.ArgumentTypeError("Shape must be M, N, K") from e
200192

201193

202194
if __name__ == "__main__":
@@ -228,27 +220,15 @@ def shape_input_type(s):
228220
)
229221

230222
parser.add_argument(
231-
"-M",
232-
type=shape_input_type,
233-
nargs="+",
234-
default=[2**x for x in range(7, 11)],
235-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
236-
)
237-
238-
parser.add_argument(
239-
"-N",
223+
"--shape",
240224
type=shape_input_type,
241225
nargs="+",
242-
default=[6656],
243-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
244-
)
245-
246-
parser.add_argument(
247-
"-K",
248-
type=shape_input_type,
249-
nargs="+",
250-
default=[2**x for x in range(12, 15)],
251-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
226+
default=[
227+
(m, 6656, k)
228+
for m in [2**x for x in range(7, 11)]
229+
for k in [2**x for x in range(12, 16)]
230+
],
231+
help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
252232
)
253233

254234
parser.add_argument("-dtype", type=str, help="dtype", default="bfloat16")

benchmark/benchmark_all_reduce.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_single_backend_fn(backend: str):
114114
if backend == "dist_2shot":
115115
return symm_mem_two_shot_all_reduce
116116
if backend == "triton_1shot":
117-
return kraken.all_reduce_fusion.one_shot_all_reduce
117+
return kraken.comm.one_shot_all_reduce
118118
if backend == "nccl":
119119
return nccl_ring
120120
raise NotImplementedError(backend)

benchmark/benchmark_all_reduce_bias.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ def one_shot_all_reduce_bias(
1515
x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
1616
) -> torch.Tensor:
1717
y = torch.empty_like(x)
18-
kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
18+
kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
1919
return y
2020

2121

2222
def two_shot_all_reduce_bias(
2323
x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
2424
) -> torch.Tensor:
2525
y = torch.empty_like(x)
26-
kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
26+
kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
2727
return y
2828

2929

benchmark/benchmark_all_reduce_bias_rms_norm.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,32 +13,36 @@
1313

1414
def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
1515
y = torch.empty_like(x)
16-
kraken.all_reduce_fusion.one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
16+
kraken.fused.one_shot_all_reduce_bias_rms_norm(
17+
symm_mem_input, x, bias, rms_weight, y
18+
)
1719
return y
1820

1921

2022
def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
2123
y = torch.empty_like(x)
22-
kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
23-
return kraken.all_reduce_fusion.rms_norm(y, rms_weight)
24+
kraken.fused.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
25+
return kraken.fused.rms_norm(y, rms_weight)
2426

2527

2628
def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
2729
y = torch.empty_like(x)
28-
kraken.all_reduce_fusion.two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
30+
kraken.fused.two_shot_all_reduce_bias_rms_norm(
31+
symm_mem_input, x, bias, rms_weight, y
32+
)
2933
return y
3034

3135

3236
def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
3337
y = torch.empty_like(x)
34-
kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
35-
return kraken.all_reduce_fusion.rms_norm(y, rms_weight)
38+
kraken.fused.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
39+
return kraken.fused.rms_norm(y, rms_weight)
3640

3741

3842
def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight):
3943
dist.all_reduce(x)
4044
y = x + bias
41-
return kraken.all_reduce_fusion.rms_norm(y, rms_weight)
45+
return kraken.fused.rms_norm(y, rms_weight)
4246

4347

4448
def create_benchmarks(b, t, d_size, device, dtype):

benchmark/benchmark_matmul_reduce_scatter.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import csv
44
from dataclasses import asdict, dataclass
55
import functools
6-
import itertools
76
import os
87
import sys
98

@@ -68,15 +67,10 @@ def asdict(self):
6867

6968
def generate_experiment_configs(
7069
dtype: torch.dtype,
71-
M: list[int],
72-
N: list[int],
73-
K: list[int],
70+
shapes: list[tuple[int, int, int]],
7471
backends: list[str],
7572
device: torch.device,
7673
) -> list[ExperimentConfig]:
77-
# Generate cross config shapes from M, N, K lists
78-
shapes = list(itertools.product(M, N, K))
79-
8074
all_configs = []
8175
for shape in shapes:
8276
all_configs.append(
@@ -98,7 +92,7 @@ def get_single_backend_fn(backend: str):
9892
if backend == "torch_symm_mem":
9993
return torch_symm_mem_gemm_rs
10094
if backend == "triton":
101-
return kraken.reduce_scatter_fusion.gemm_reduce_scatter
95+
return kraken.fused.gemm_reduce_scatter
10296
raise NotImplementedError(backend)
10397

10498

@@ -181,9 +175,7 @@ def main(args):
181175
torch.manual_seed(42 + local_rank)
182176

183177
results = []
184-
configs = generate_experiment_configs(
185-
args.dtype, args.M, args.N, args.K, args.backend, device
186-
)
178+
configs = generate_experiment_configs(args.dtype, args.shape, args.backend, device)
187179
for config in configs:
188180
results.append(
189181
Experiment(
@@ -201,7 +193,7 @@ def shape_input_type(s):
201193
M, N, K = map(int, s.split(","))
202194
return M, N, K
203195
except Exception as e:
204-
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
196+
raise argparse.ArgumentTypeError("Shape must be M, N, K") from e
205197

206198

207199
if __name__ == "__main__":
@@ -233,27 +225,15 @@ def shape_input_type(s):
233225
)
234226

235227
parser.add_argument(
236-
"-M",
237-
type=shape_input_type,
238-
nargs="+",
239-
default=[2**x for x in range(7, 11)],
240-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
241-
)
242-
243-
parser.add_argument(
244-
"-N",
228+
"--shape",
245229
type=shape_input_type,
246230
nargs="+",
247-
default=[6656],
248-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
249-
)
250-
251-
parser.add_argument(
252-
"-K",
253-
type=shape_input_type,
254-
nargs="+",
255-
default=[2**x for x in range(12, 16)],
256-
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
231+
default=[
232+
(m, 6656, k)
233+
for m in [2**x for x in range(7, 11)]
234+
for k in [2**x for x in range(12, 16)]
235+
],
236+
help="matmul shapes: M, N, K. (M, K) @ (K, N) -> (M, N)",
257237
)
258238

259239
parser.add_argument("-dtype", type=str, help="dtype", default="float32")

kraken/__init__.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from . import (
22
_logging,
3-
all_gather_fusion,
4-
all_reduce,
5-
all_reduce_fusion,
6-
reduce_scatter_fusion,
3+
_ptx_utils,
4+
comm,
5+
fused,
76
)
87

98
__all__ = [
109
"_logging",
11-
"all_gather_fusion",
12-
"all_reduce",
13-
"all_reduce_fusion",
14-
"reduce_scatter_fusion",
10+
"_ptx_utils",
11+
"comm",
12+
"fused",
1513
]

kraken/all_gather_fusion/__init__.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

kraken/all_gather_fusion/copy_engine_all_gather.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

0 commit comments

Comments
 (0)