
Commit 5f89ca3

lint and fix file structure

1 parent: 9557d21

22 files changed: +81 −73 lines

README.md — 1 addition, 1 deletion

@@ -82,7 +82,7 @@ Kraken is organized for easy hacking of distributed Triton kernels:
 ### Inline PTX Utils
-
+`kraken._ptx_utils` provides inline PTX implementations of memory-barrier synchronizations that are not natively supported by Triton.
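For context, the usual escape hatch for inline PTX in Triton is `tl.inline_asm_elementwise`, which maps a PTX snippet over a block of values. The sketch below only illustrates that mechanism with a made-up rotate kernel; it is not Kraken's code, whose actual barrier primitives live in `kraken._ptx_utils`.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _rotl7_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Funnel shift with both sources equal rotates each 32-bit lane left by 7.
    y = tl.inline_asm_elementwise(
        "shf.l.wrap.b32 $0, $1, $1, 7;",
        "=r,r",  # $0: output register, $1: input register
        [x],
        dtype=tl.int32,
        is_pure=True,
        pack=1,
    )
    tl.store(y_ptr + offs, y)

x = torch.arange(1024, device="cuda", dtype=torch.int32)
y = torch.empty_like(x)
_rotl7_kernel[(1,)](x, y, BLOCK=1024)
```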
benchmark/benchmark_all_gather_matmul.py — 1 addition, 1 deletion

@@ -93,7 +93,7 @@ def get_single_backend_fn(backend: str):
     if backend == "torch_symm_mem":
         return torch_symm_mem_ag_mm
     if backend == "triton":
-        return kraken.all_gather.all_gather_matmul
+        return kraken.all_gather_fusion.all_gather_matmul
     raise NotImplementedError(backend)
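Callers written against the old module path need the same update. A hypothetical compatibility shim (not part of this commit) could bridge both layouts:

```python
import kraken

# Hypothetical shim for callers that predate the rename: prefer the new
# all_gather_fusion module, fall back to the old all_gather one if present.
try:
    all_gather_matmul = kraken.all_gather_fusion.all_gather_matmul
except AttributeError:
    all_gather_matmul = kraken.all_gather.all_gather_matmul
```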

benchmark/benchmark_all_reduce.py — 1 addition, 1 deletion

@@ -114,7 +114,7 @@ def get_single_backend_fn(backend: str):
     if backend == "dist_2shot":
         return symm_mem_two_shot_all_reduce
     if backend == "triton_1shot":
-        return kraken.all_reduce.one_shot_all_reduce
+        return kraken.all_reduce_fusion.one_shot_all_reduce
     if backend == "nccl":
         return nccl_ring
     raise NotImplementedError(backend)
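The backend names reflect the two standard symmetric-memory all-reduce schemes: in a one-shot kernel every rank reads all peers' buffers and reduces locally, while a two-shot kernel first reduce-scatters shards and then all-gathers the reduced shards. A rough reference for the two-shot scheme built from stock collectives (my sketch, not Kraken's kernel):

```python
import torch
import torch.distributed as dist

def two_shot_all_reduce_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumes x is 1-D with numel divisible by the world size.
    world = dist.get_world_size()
    shard = torch.empty(x.numel() // world, dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(shard, x)     # shot 1: each rank reduces one shard
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)  # shot 2: shards recombined everywhere
    return out
```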

benchmark/benchmark_all_reduce_bias.py — 3 additions, 6 deletions

@@ -7,26 +7,23 @@
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem

+import kraken
 from kraken import _logging as log
-from kraken.all_reduce_fusion import (
-    one_shot_all_reduce_bias,
-    two_shot_all_reduce_bias,
-)


 def one_shot_all_reduce_bias(
     x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
 ) -> torch.Tensor:
     y = torch.empty_like(x)
-    one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
     return y


 def two_shot_all_reduce_bias(
     x: torch.Tensor, bias: torch.Tensor, symm_mem_input: torch.Tensor
 ) -> torch.Tensor:
     y = torch.empty_like(x)
-    two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
     return y
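The substantive fix here is a name-shadowing bug: each benchmark wrapper shared its name with the function it imported, so the call inside the body resolved to the wrapper itself and recursed forever. Module-qualified calls sidestep that. A stripped-down illustration with a standard-library function:

```python
from math import sqrt

def sqrt(x):        # rebinds the module-level name, shadowing math.sqrt
    return sqrt(x)  # resolves to this wrapper itself -> RecursionError

# The fix applied above: keep the library call module-qualified.
import math

def safe_sqrt(x):
    return math.sqrt(x)
```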
benchmark/benchmark_all_reduce_bias_rms_norm.py — 8 additions, 14 deletions

@@ -7,44 +7,38 @@
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem

+import kraken
 from kraken import _logging as log
-from kraken.all_reduce_fusion import (
-    rms_norm,
-    one_shot_all_reduce_bias,
-    one_shot_all_reduce_bias_rms_norm,
-    two_shot_all_reduce_bias,
-    two_shot_all_reduce_bias_rms_norm,
-)


 def one_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+    kraken.all_reduce_fusion.one_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
     return y


 def one_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
-    return rms_norm(y, rms_weight)
+    kraken.all_reduce_fusion.one_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)


 def two_shot_all_reduce_bias_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
+    kraken.all_reduce_fusion.two_shot_all_reduce_bias_rms_norm(symm_mem_input, x, bias, rms_weight, y)
     return y


 def two_shot_all_reduce_bias_with_rms_norm(x, bias, rms_weight, symm_mem_input):
     y = torch.empty_like(x)
-    two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
-    return rms_norm(y, rms_weight)
+    kraken.all_reduce_fusion.two_shot_all_reduce_bias(symm_mem_input, x, bias, y)
+    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)


 def nccl_all_reduce_bias_rms_norm(x, bias, rms_weight):
     dist.all_reduce(x)
     y = x + bias
-    return rms_norm(y, rms_weight)
+    return kraken.all_reduce_fusion.rms_norm(y, rms_weight)


 def create_benchmarks(b, t, d_size, device, dtype):
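For checking the fused kernels' output, a plain eager RMSNorm is a handy reference. The formula is standard (x scaled by its reciprocal root-mean-square over the last dimension, then an elementwise weight); the signature and epsilon below are my assumptions, not Kraken's API:

```python
import torch

def rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # RMSNorm over the last dim: x / sqrt(mean(x^2) + eps), scaled elementwise.
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return x * rms * weight
```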

benchmark/benchmark_matmul_reduce_scatter.py — 13 additions, 9 deletions

@@ -1,16 +1,16 @@
 import argparse
+from collections import defaultdict
 import csv
+from dataclasses import asdict, dataclass
 import functools
 import itertools
 import os
 import sys
-from collections import defaultdict
-from dataclasses import asdict, dataclass

+from tabulate import tabulate
 import torch
 import torch.distributed as dist
 import torch.distributed._symmetric_memory as symm_mem
-from tabulate import tabulate

 # Add the kraken directory to the Python path
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

@@ -20,18 +20,22 @@


 def torch_symm_mem_gemm_rs(a, b):
-    gemm_rs_output = torch.ops.symm_mem.fused_matmul_reduce_scatter(
+    return torch.ops.symm_mem.fused_matmul_reduce_scatter(
         a, b, "sum", scatter_dim=0, group_name=dist.group.WORLD.group_name
     )
-    return gemm_rs_output
+

 def nccl_mem_gemm_rs(a, b):
-    from torch.distributed._functional_collectives import reduce_scatter_tensor, wait_tensor
+    from torch.distributed._functional_collectives import (
+        reduce_scatter_tensor,
+        wait_tensor,
+    )

     gemm_output = torch.matmul(a, b)
-    rs_o = reduce_scatter_tensor(gemm_output, "sum", scatter_dim=0, group=dist.group.WORLD)
-    gemm_rs_output = wait_tensor(rs_o)
-    return gemm_rs_output
+    rs_o = reduce_scatter_tensor(
+        gemm_output, "sum", scatter_dim=0, group=dist.group.WORLD
+    )
+    return wait_tensor(rs_o)


 @dataclass(frozen=True)
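One detail worth calling out in `nccl_mem_gemm_rs`: functional collectives return without blocking, and the result must pass through `wait_tensor` before it is mixed with eager ops. A minimal sketch of the same pattern with `all_reduce`:

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_reduce, wait_tensor

def functional_sum(x: torch.Tensor) -> torch.Tensor:
    y = all_reduce(x, "sum", group=dist.group.WORLD)  # returns without blocking
    return wait_tensor(y)  # block until the collective result is valid to read
```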

kraken/__init__.py — 14 additions, 2 deletions

@@ -1,3 +1,15 @@
-from . import _logging, all_gather, all_reduce, all_reduce_fusion, reduce_scatter_fusion
+from . import (
+    _logging,
+    all_gather_fusion,
+    all_reduce,
+    all_reduce_fusion,
+    reduce_scatter_fusion,
+)

-__all__ = ["_logging", "all_gather", "all_reduce", "all_reduce_fusion", "reduce_scatter_fusion"]
+__all__ = [
+    "_logging",
+    "all_gather_fusion",
+    "all_reduce",
+    "all_reduce_fusion",
+    "reduce_scatter_fusion",
+]
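The `__all__` change matters for star-importers: `from kraken import *` now binds `all_gather_fusion` and no longer binds `all_gather`. A hypothetical caller:

```python
# Hypothetical star-importing caller; after this commit all_gather_fusion
# is exported (and all_gather no longer is), per the updated __all__.
from kraken import *

ag_mm = all_gather_fusion.all_gather_matmul
```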

kraken/_ptx_utils/__init__.py — 6 additions, 2 deletions

@@ -1,15 +1,19 @@
 from .gmem_barrier_arrive_wait import arrive_gmem_barrier, wait_gmem_barrier
 from .symm_mem_barrier import (
     _get_flat_tid as get_flat_tid,
+)
+from .symm_mem_barrier import (
     _send_signal as send_signal,
+)
+from .symm_mem_barrier import (
     symm_mem_sync as symm_mem_sync,
 )

 __all__ = [
     "arrive_gmem_barrier",
-    "symm_mem_sync",
-    "wait_gmem_barrier",
     "get_flat_tid",
     "send_signal",
+    "symm_mem_sync",
+    "wait_gmem_barrier",
 ]
 # Avoid ptx_utils when possible
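The odd-looking `symm_mem_sync as symm_mem_sync` follows the explicit re-export convention from PEP 484: a redundant `import X as X` alias tells type checkers and linters that the name is intentionally part of the module's public surface. The same pattern with a standard-library module:

```python
# A redundant alias marks an intentional public re-export rather than an
# accidental import; type checkers treat `path` as part of this module's API.
from os import path as path
```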
File renamed without changes.
